From 65079b8de631dcd1bb7271493627581678bd922c Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Mon, 4 Dec 2023 12:09:39 -0800 Subject: [PATCH] client breaking: add initial version of client handlers --- client/Cargo.toml | 13 ++- client/examples/cli.rs | 181 +++++++++++++++++++++++++++++++++ client/src/client.rs | 54 +++++++--- client/src/client_handler.rs | 93 +++++++++++++++++ client/src/client_like.rs | 5 +- client/src/conn.rs | 49 ++++++--- client/src/cookies.rs | 46 +++++++++ client/src/follow_redirects.rs | 53 ++++++++++ client/src/lib.rs | 15 ++- server-common/src/client.rs | 12 +++ 10 files changed, 488 insertions(+), 33 deletions(-) create mode 100644 client/examples/cli.rs create mode 100644 client/src/client_handler.rs create mode 100644 client/src/cookies.rs create mode 100644 client/src/follow_redirects.rs diff --git a/client/Cargo.toml b/client/Cargo.toml index 01c3bb6cdf..d64c6c9ef4 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -11,7 +11,8 @@ keywords = ["trillium", "framework", "async"] categories = ["web-programming", "web-programming::http-client"] [features] -json = ["serde_json", "serde", "thiserror"] +cookies = ["dep:cookie_store", "dep:async-lock"] +json = ["dep:serde_json", "dep:serde", "dep:thiserror"] [dependencies] encoding_rs = "0.8.33" @@ -28,6 +29,9 @@ thiserror = { version = "1.0.50", optional = true } dashmap = "5.5.3" crossbeam-queue = "0.3.8" memchr = "2.6.4" +async-lock = {version = "3.2.0", optional = true } +arc-swap = "1.6.0" +cookie_store = { version = "0.20.0", optional = true } [dependencies.trillium-http] path = "../http" @@ -36,12 +40,19 @@ version = "0.3.6" [dev-dependencies] async-channel = "2.1.0" +async-fs = "2.1.0" async-global-executor = "2.3.1" +blocking = "1.5.1" +clap = { version = "4.4.10", features = ["derive", "env"] } +clap-verbosity-flag = "2.1.0" +crossbeam = "0.8.2" env_logger = "0.10.1" indoc = "2.0.4" pretty_assertions = "1.4.0" test-harness = "0.1.1" trillium = { path = "../trillium" } +trillium-native-tls = { path = "../native-tls" } +trillium-rustls = { path = "../rustls" } trillium-smol = { path = "../smol/" } trillium-testing = { path = "../testing" } diff --git a/client/examples/cli.rs b/client/examples/cli.rs new file mode 100644 index 0000000000..6269573491 --- /dev/null +++ b/client/examples/cli.rs @@ -0,0 +1,181 @@ +use blocking::Unblock; +use clap::Parser; +use std::{ + io::{ErrorKind, IsTerminal}, + path::PathBuf, + str::FromStr, +}; +use trillium::{Body, Method}; +use trillium_client::{Client, Conn, Error, FollowRedirects}; +use trillium_native_tls::NativeTlsConfig; +use trillium_rustls::RustlsConfig; +use trillium_smol::ClientConfig; +use url::{self, Url}; + +pub fn main() { + ClientCli::parse().run() +} + +#[derive(Parser, Debug)] +pub struct ClientCli { + #[arg(value_parser = parse_method_case_insensitive)] + method: Method, + + #[arg(value_parser = parse_url)] + url: Url, + + /// provide a file system path to a file to use as the request body + /// + /// alternatively, you can use an operating system pipe to pass a file in + /// + /// three equivalent examples: + /// + /// trillium client post http://httpbin.org/anything -f ./body.json + /// trillium client post http://httpbin.org/anything < ./body.json + /// cat ./body.json | trillium client post http://httpbin.org/anything + #[arg(short, long, verbatim_doc_comment)] + file: Option, + + /// provide a request body on the command line + /// + /// example: + /// trillium client post http://httpbin.org/post -b '{"hello": "world"}' + #[arg(short, long, verbatim_doc_comment)] + body: Option, + + /// provide headers in the form -h KEY1=VALUE1 KEY2=VALUE2 + /// + /// example: + /// trillium client get http://httpbin.org/headers -H Accept=application/json Authorization="Basic u:p" + #[arg(short = 'H', long, value_parser = parse_header, verbatim_doc_comment)] + headers: Vec<(String, String)>, + + /// tls implementation. options: rustls, native-tls, none + /// + /// requests to https:// urls with `none` will fail + #[arg(short, long, default_value = "rustls", verbatim_doc_comment)] + tls: TlsType, + + /// set the log level. add more flags for more verbosity + /// + /// example: + /// trillium client get https://www.google.com -vvv # `trace` verbosity level + #[command(flatten)] + verbose: clap_verbosity_flag::Verbosity, +} + +impl ClientCli { + async fn build(&self) -> Conn { + let mut client = match self.tls { + TlsType::None => Client::new(ClientConfig::default()), + TlsType::Rustls => Client::new(RustlsConfig::::default()), + TlsType::NativeTls => Client::new(NativeTlsConfig::::default()), + }; + + client.set_handler(FollowRedirects::new()); + + let mut conn = client.build_conn(self.method, self.url.clone()); + for (name, value) in &self.headers { + conn.request_headers().append(name.clone(), value.clone()); + } + + if let Some(path) = &self.file { + let metadata = async_fs::metadata(path) + .await + .unwrap_or_else(|e| panic!("could not read file {:?} ({})", path, e)); + + let file = async_fs::File::open(path) + .await + .unwrap_or_else(|e| panic!("could not read file {:?} ({})", path, e)); + + conn.with_body(Body::new_streaming(file, Some(metadata.len()))) + } else if let Some(body) = &self.body { + conn.with_body(body.clone()) + } else if !std::io::stdin().is_terminal() { + conn.with_body(Body::new_streaming(Unblock::new(std::io::stdin()), None)) + } else { + conn + } + } + + pub fn run(self) { + trillium_smol::async_global_executor::block_on(async move { + env_logger::Builder::new() + .filter_level(self.verbose.log_level_filter()) + .init(); + + let mut conn = self.build().await; + + if let Err(e) = (&mut conn).await { + match e { + Error::Io(io) if io.kind() == ErrorKind::ConnectionRefused => { + log::error!("could not reach {}", self.url) + } + + _ => log::error!("protocol error:\n\n{}", e), + } + + return; + } + + if std::io::stdout().is_terminal() { + let body = conn.response_body().read_string().await.unwrap(); + + let _request_headers_as_string = format!("{:#?}", conn.request_headers()); + let headers = conn.response_headers(); + let _response_headers_as_string = format!("{:#?}", headers); + let _status_string = conn.status().unwrap().to_string(); + println!("{conn:#?}"); + println!("{body}"); + } else { + futures_lite::io::copy( + &mut conn.response_body(), + &mut Unblock::new(std::io::stdout()), + ) + .await + .unwrap(); + } + }); + } +} + +#[derive(clap::ValueEnum, Debug, Eq, PartialEq, Clone)] +enum TlsType { + None, + Rustls, + NativeTls, +} + +fn parse_method_case_insensitive(src: &str) -> Result { + src.to_uppercase() + .parse() + .map_err(|_| format!("unrecognized method {}", src)) +} + +fn parse_url(src: &str) -> Result { + if src.starts_with("http") { + src.parse() + } else { + format!("http://{}", src).parse() + } +} + +impl FromStr for TlsType { + type Err = String; + + fn from_str(s: &str) -> Result { + match &*s.to_ascii_lowercase() { + "none" => Ok(Self::None), + "rustls" => Ok(Self::Rustls), + "native" | "native-tls" => Ok(Self::NativeTls), + _ => Err(format!("unrecognized tls {}", s)), + } + } +} + +fn parse_header(s: &str) -> Result<(String, String), String> { + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{}`", s))?; + Ok((String::from(&s[..pos]), String::from(&s[pos + 1..]))) +} diff --git a/client/src/client.rs b/client/src/client.rs index b58c4c9e82..2eacf59db6 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -1,4 +1,6 @@ -use crate::{ClientLike, Conn, Pool}; +use crate::{client_handler::ClientHandler, ClientLike, Conn, Pool}; +use arc_swap::ArcSwapOption; + use std::{convert::TryInto, fmt::Debug, sync::Arc}; use trillium_http::{transport::BoxedTransport, Method}; use trillium_server_common::{Connector, ObjectSafeConnector, Url}; @@ -9,10 +11,15 @@ A client contains a Config and an optional connection pool and builds conns. */ + #[derive(Clone, Debug)] -pub struct Client { - config: Arc, - pool: Option>, +pub struct Client(Arc); + +#[derive(Debug)] +pub struct ClientInner { + config: Box, + pool: ArcSwapOption>, + handler: ArcSwapOption>, } macro_rules! method { @@ -59,10 +66,11 @@ assert_eq!(conn.url().to_string(), \"http://localhost:8080/some/route\"); impl Client { /// builds a new client from this `Connector` pub fn new(config: impl Connector) -> Self { - Self { - config: config.arced(), - pool: None, - } + Self(Arc::new(ClientInner { + config: config.boxed(), + pool: ArcSwapOption::empty(), + handler: ArcSwapOption::empty(), + })) } /** @@ -78,8 +86,8 @@ impl Client { .with_default_pool(); //<- ``` */ - pub fn with_default_pool(mut self) -> Self { - self.pool = Some(Pool::default()); + pub fn with_default_pool(self) -> Self { + self.0.pool.store(Some(Arc::new(Pool::default()))); self } @@ -109,15 +117,16 @@ impl Client { U: TryInto, >::Error: Debug, { - let mut conn = Conn::new_with_config( - Arc::clone(&self.config), + let mut conn = Conn::new_with_client( + self.clone(), method.try_into().unwrap(), url.try_into().unwrap(), ); - if let Some(pool) = &self.pool { + if let Some(pool) = self.0.pool.load_full().as_deref() { conn.set_pool(pool.clone()); } + conn } @@ -128,7 +137,7 @@ impl Client { intermittently. */ pub fn clean_up_pool(&self) { - if let Some(pool) = &self.pool { + if let Some(pool) = &*self.0.pool.load() { pool.cleanup(); } } @@ -138,6 +147,23 @@ impl Client { method!(put, Put); method!(delete, Delete); method!(patch, Patch); + + pub(crate) fn handler(&self) -> Option>> { + self.0.handler.load_full() + } + /// + pub fn with_handler(mut self, handler: impl ClientHandler) -> Self { + self.set_handler(handler); + self + } + /// + pub fn set_handler(&mut self, handler: impl ClientHandler) { + self.0.handler.store(Some(Arc::new(Box::new(handler)))) + } + + pub(crate) fn connector(&self) -> &dyn ObjectSafeConnector { + &self.0.config + } } impl From for Client { diff --git a/client/src/client_handler.rs b/client/src/client_handler.rs new file mode 100644 index 0000000000..c1d74adb89 --- /dev/null +++ b/client/src/client_handler.rs @@ -0,0 +1,93 @@ +use std::borrow::Cow; + +use crate::Conn; +use trillium_server_common::async_trait; + +#[async_trait] +pub trait ClientHandler: std::fmt::Debug + Send + Sync + 'static { + async fn before(&self, conn: &mut Conn) -> crate::Result<()> { + let _ = conn; + Ok(()) + } + + async fn after(&self, conn: &mut Conn) -> crate::Result<()> { + let _ = conn; + Ok(()) + } + + fn name(&self) -> Cow<'static, str> { + std::any::type_name::().into() + } +} + +impl ClientHandler for () {} + +#[async_trait] +impl ClientHandler for Option { + async fn before(&self, conn: &mut Conn) -> crate::Result<()> { + match self { + Some(h) => h.before(conn).await, + None => Ok(()), + } + } + + async fn after(&self, conn: &mut Conn) -> crate::Result<()> { + match self { + Some(h) => h.after(conn).await, + None => Ok(()), + } + } + + fn name(&self) -> Cow<'static, str> { + match self { + Some(h) => h.name(), + None => "None".into(), + } + } +} + +macro_rules! impl_handler_tuple { + ($($name:ident)+) => ( + #[async_trait] + impl<$($name),*> ClientHandler for ($($name,)*) where $($name: ClientHandler),* { + #[allow(non_snake_case)] + async fn before(&self, conn: &mut Conn) -> crate::Result<()> { + let ($(ref $name,)*) = *self; + $( + log::trace!("running {}", ($name).name()); + ($name).before(conn).await?; + )* + Ok(()) + } + #[allow(non_snake_case)] + async fn after(&self, conn: &mut Conn) -> crate::Result<()> { + let ($(ref $name,)*) = *self; + $( + log::trace!("running {}", ($name).name()); + ($name).after(conn).await?; + )* + Ok(()) + } + + #[allow(non_snake_case)] + fn name(&self) -> Cow<'static, str> { + let ($(ref $name,)*) = *self; + format!(concat!("(\n", $( + concat!(" {",stringify!($name) ,":},\n") + ),*, ")"), $($name = ($name).name()),*).into() + } + } + ); +} +impl_handler_tuple! { A } +impl_handler_tuple! { A B } +impl_handler_tuple! { A B C } +impl_handler_tuple! { A B C D } +impl_handler_tuple! { A B C D E } +impl_handler_tuple! { A B C D E F } +impl_handler_tuple! { A B C D E F G } +impl_handler_tuple! { A B C D E F G H } +impl_handler_tuple! { A B C D E F G H I } +impl_handler_tuple! { A B C D E F G H I J } +impl_handler_tuple! { A B C D E F G H I J K } +impl_handler_tuple! { A B C D E F G H I J K L } diff --git a/client/src/client_like.rs b/client/src/client_like.rs index c963fd52d3..2a98be7e99 100644 --- a/client/src/client_like.rs +++ b/client/src/client_like.rs @@ -1,4 +1,4 @@ -use crate::Conn; +use crate::{Client, Conn}; use trillium_http::Method; use trillium_server_common::{Connector, ObjectSafeConnector, Url}; @@ -41,6 +41,7 @@ pub trait ClientLike { impl ClientLike for C { fn build_conn(&self, method: Method, url: Url) -> Conn { - Conn::new_with_config(self.clone().arced(), method, url) + let client = Client::new(self.clone().arced()); + Conn::new_with_client(client, method, url) } } diff --git a/client/src/conn.rs b/client/src/conn.rs index 7790a2adf6..62384ebec1 100644 --- a/client/src/conn.rs +++ b/client/src/conn.rs @@ -1,4 +1,4 @@ -use crate::{pool::PoolEntry, util::encoding, Pool}; +use crate::{client_handler::ClientHandler, pool::PoolEntry, util::encoding, Client, Pool}; use encoding_rs::Encoding; use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt}; use memchr::memmem::Finder; @@ -10,7 +10,6 @@ use std::{ ops::{Deref, DerefMut}, pin::Pin, str::FromStr, - sync::Arc, }; use trillium_http::{ transport::BoxedTransport, @@ -20,7 +19,7 @@ use trillium_http::{ }, Method, ReceivedBody, ReceivedBodyState, Result, StateSet, Status, Stopper, Upgrade, }; -use trillium_server_common::{Connector, ObjectSafeConnector, Transport}; +use trillium_server_common::{Connector, Transport}; use url::{Origin, Url}; const MAX_HEADERS: usize = 128; @@ -59,8 +58,9 @@ pub struct Conn { pool: Option>, buffer: trillium_http::Buffer, response_body_state: ReceivedBodyState, - config: Arc, headers_finalized: bool, + client: Client, + state: StateSet, } /// default http user-agent header @@ -78,7 +78,7 @@ impl Debug for Conn { .field("pool", &self.pool) .field("buffer", &String::from_utf8_lossy(&self.buffer)) .field("response_body_state", &self.response_body_state) - .field("config", &self.config) + .field("client", &self.client) .finish() } } @@ -132,11 +132,17 @@ impl Conn { // ) // } - pub(crate) fn new_with_config( - config: Arc, - method: Method, - url: Url, - ) -> Self { + /// document + pub fn state(&self) -> &StateSet { + &self.state + } + + /// document + pub fn state_mut(&mut self) -> &mut StateSet { + &mut self.state + } + + pub(crate) fn new_with_client(client: Client, method: Method, url: Url) -> Self { Self { url, method, @@ -148,11 +154,17 @@ impl Conn { pool: None, buffer: Vec::with_capacity(128).into(), response_body_state: ReceivedBodyState::Start, - config, + client, headers_finalized: false, + state: StateSet::new(), } } + ///document + pub fn client(&self) -> &Client { + &self.client + } + /** retrieves a mutable borrow of the request headers, suitable for appending a header. generally, prefer using chainable methods on @@ -197,7 +209,7 @@ impl Conn { use trillium_testing::ClientConfig; - let handler = |conn: trillium::Conn| async move { + let handler= |conn: trillium::Conn| async move { let header = conn.headers().get_str("some-request-header").unwrap_or_default(); let response = format!("some-request-header was {}", header); conn.ok(response) @@ -633,7 +645,7 @@ impl Conn { } None => { - let mut transport = Connector::connect(&self.config, &self.url).await?; + let mut transport = Connector::connect(self.client.connector(), &self.url).await?; log::debug!("opened new connection to {:?}", transport.peer_addr()?); transport.write_all(&head).await?; transport @@ -820,10 +832,17 @@ impl Conn { } } - async fn exec(&mut self) -> Result<()> { + pub async fn exec(&mut self) -> Result<()> { + let handler = self.client.handler(); + if let Some(ref handler) = handler { + handler.before(self).await?; + } self.finalize_headers(); self.connect_and_send_head().await?; self.send_body_and_parse_head().await?; + if let Some(handler) = handler { + handler.after(self).await?; + } Ok(()) } } @@ -860,7 +879,7 @@ impl Drop for Conn { let buffer = std::mem::take(&mut self.buffer); let response_body_state = self.response_body_state; let encoding = encoding(&self.response_headers); - Connector::spawn(&self.config, async move { + Connector::spawn(self.client.connector(), async move { let mut response_body = ReceivedBody::new( content_length, buffer, diff --git a/client/src/cookies.rs b/client/src/cookies.rs new file mode 100644 index 0000000000..f126298d66 --- /dev/null +++ b/client/src/cookies.rs @@ -0,0 +1,46 @@ +use crate::{client_handler::ClientHandler, Conn, KnownHeaderName}; +use async_lock::RwLock; +use cookie_store::CookieStore; +use std::sync::Arc; + +pub struct Cookies { + store: RwLock, +} + +impl Cookies { + pub fn new() -> Self { + Self { + store: Arc::new(RwLock::new(CookieStore::new(None))), + } + } +} + +#[crate::async_trait] +impl ClientHandler for Cookies { + async fn before(&self, conn: &mut Conn) -> crate::Result<()> { + let mut matches = self.store.read().await.matches(conn.url()); + matches.sort_by(|a, b| b.path.len().cmp(&a.path.len())); + let values = matches + .iter() + .map(|cookie| format!("{}={}", cookie.name(), cookie.value())) + .collect::>() + .join("; "); + conn.request_headers() + .insert(KnownHeaderName::Cookie, values); + Ok(()) + } + + async fn after(&self, conn: &mut Conn) -> crate::Result<()> { + if let Some(set_cookies) = conn.response_headers().get_all(KnownHeaderName::SetCookie) { + let mut cookie_store = self.store.write().await; + for cookie in set_cookies { + match cookie_store.parse(cookie.as_str(), request_url) { + Ok(action) => log::trace!("cookie action: {:?}", action), + Err(e) => log::trace!("cookie parse error: {:?}", e), + } + } + } + + Ok(()) + } +} diff --git a/client/src/follow_redirects.rs b/client/src/follow_redirects.rs new file mode 100644 index 0000000000..675020bdb1 --- /dev/null +++ b/client/src/follow_redirects.rs @@ -0,0 +1,53 @@ +use std::mem; + +use crate::{async_trait, client_handler::ClientHandler, Conn, KnownHeaderName, Result}; + +#[derive(Debug, Default, Copy, Clone)] +pub struct FollowRedirects { + _private: (), +} + +impl FollowRedirects { + pub fn new() -> Self { + Self { _private: () } + } +} + +#[derive(Default, Debug)] +pub struct RedirectHistory(Vec); + +#[async_trait] +impl ClientHandler for FollowRedirects { + async fn after(&self, conn: &mut Conn) -> Result<()> { + let client = conn.client().clone(); + + if !matches!(conn.status(), Some(status) if status.is_redirection()) + || !conn.method().is_safe() + { + return Ok(()); + } + + let Some(location) = conn.response_headers().get_str(KnownHeaderName::Location) else { + return Ok(()); + }; + + let mut new_conn = client.build_conn(conn.method(), location); + new_conn.request_headers().append_all( + conn.request_headers() + .clone() + .without_header(KnownHeaderName::Host), + ); + + mem::swap(new_conn.state_mut(), conn.state_mut()); + let old_conn = mem::replace(conn, new_conn); + + conn.state_mut() + .get_or_insert_with(RedirectHistory::default) + .0 + .push(old_conn); + + conn.exec().await?; + + Ok(()) + } +} diff --git a/client/src/lib.rs b/client/src/lib.rs index 02b1d1b046..ba4427f585 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -4,10 +4,10 @@ missing_copy_implementations, rustdoc::missing_crate_level_docs, missing_debug_implementations, - missing_docs, nonstandard_style, unused_qualifications )] +#![warn(missing_docs)] /*! trillium client is a http client that uses the same `conn` approach as @@ -30,6 +30,7 @@ examples. */ +mod client_handler; mod conn; pub use conn::{Conn, UnexpectedStatusError, USER_AGENT}; @@ -58,3 +59,15 @@ pub use client_like::ClientLike; pub fn client(connector: impl Connector) -> Client { Client::new(connector) } + +pub use client_handler::ClientHandler; + +#[cfg(feature = "cookies")] +mod cookies; +#[cfg(feature = "cookies")] +pub use cookies::Cookies; + +mod follow_redirects; +pub use follow_redirects::{FollowRedirects, RedirectHistory}; + +pub use trillium_server_common::async_trait; diff --git a/server-common/src/client.rs b/server-common/src/client.rs index e70b4d1dec..72b3c8fe26 100644 --- a/server-common/src/client.rs +++ b/server-common/src/client.rs @@ -80,6 +80,18 @@ impl Connector for Box { } } +#[async_trait] +impl Connector for dyn ObjectSafeConnector { + type Transport = BoxedTransport; + async fn connect(&self, url: &Url) -> Result { + ObjectSafeConnector::connect(self, url).await + } + + fn spawn + Send + 'static>(&self, fut: Fut) { + ObjectSafeConnector::spawn(self, Box::pin(fut)) + } +} + #[async_trait] impl Connector for Arc { type Transport = BoxedTransport;