From b20149d81f08e1369cab8a2fe94422382469c1e4 Mon Sep 17 00:00:00 2001 From: Gregory Hill Date: Tue, 13 Oct 2020 10:30:03 +0100 Subject: [PATCH] conditional compilation for async roundtripper Signed-off-by: Gregory Hill --- Cargo.toml | 4 + src/client.rs | 231 +++++++++++++++++++++++++++++++++++++++++++------- src/error.rs | 2 +- src/util.rs | 8 +- 4 files changed, 207 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c54572bf..d9e204a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,10 @@ documentation = "https://docs.rs/jsonrpc/" description = "Rust support for the JSON-RPC 2.0 protocol" keywords = [ "protocol", "json", "http", "jsonrpc" ] readme = "README.md" +edition = "2018" + +[features] +async = [] [lib] name = "jsonrpc" diff --git a/src/client.rs b/src/client.rs index cb784aab..68c4fd55 100644 --- a/src/client.rs +++ b/src/client.rs @@ -18,18 +18,18 @@ //! and parsing responses //! -use std::{error, io}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use std::{error, io}; -use serde; use base64; use http; +use serde; use serde_json; use super::{Request, Response}; -use util::HashableValue; -use error::Error; +use crate::error::Error; +use crate::util::HashableValue; /// An interface for an HTTP roundtripper that handles HTTP requests. pub trait HttpRoundTripper { @@ -38,15 +38,38 @@ pub trait HttpRoundTripper { /// The type for errors generated by the roundtripper. type Err: error::Error; - /// Make an HTTP request. In practice only POST request will be made. + /// Make a synchronous HTTP request. In practice only POST request will be made. fn request( &self, - http::Request<&[u8]>, + _request: http::Request<&[u8]>, ) -> Result, Self::Err>; } +/// An interface for an asynchronous HTTP roundtripper that handles HTTP requests. +#[cfg(feature = "async")] +pub trait AsyncHttpRoundTripper { + /// The type of the http::Response body. + type ResponseBody: io::Read; + /// The type for errors generated by the roundtripper. + type Err: error::Error; + + /// Make an asynchronous HTTP request. In practice only POST request will be made. + fn request<'life>( + &'life self, + _request: http::Request<&'life [u8]>, + ) -> std::pin::Pin< + Box< + dyn std::future::Future, Self::Err>> + + Send + + 'life, + >, + > + where + Self: Sync + 'life; +} + /// A handle to a remote JSONRPC server -pub struct Client { +pub struct Client { url: String, user: Option, pass: Option, @@ -54,14 +77,14 @@ pub struct Client { nonce: Arc>, } -impl Client { +impl Client { /// Creates a new client pub fn new( - roundtripper: Rt, + roundtripper: R, url: String, user: Option, pass: Option, - ) -> Client { + ) -> Client { // Check that if we have a password, we have a username; other way around is ok debug_assert!(pass.is_none() || user.is_some()); @@ -74,6 +97,29 @@ impl Client { } } + /// Builds a request + pub fn build_request<'a, 'b>( + &self, + name: &'a str, + params: &'b [serde_json::Value], + ) -> Request<'a, 'b> { + let mut nonce = self.nonce.lock().unwrap(); + *nonce += 1; + Request { + method: name, + params: params, + id: From::from(*nonce), + jsonrpc: Some("2.0"), + } + } + + /// Accessor for the last-used nonce + pub fn last_nonce(&self) -> u64 { + *self.nonce.lock().unwrap() + } +} + +impl Client { /// Make a request and deserialize the response pub fn do_rpc serde::de::Deserialize<'a>>( &self, @@ -82,7 +128,6 @@ impl Client { ) -> Result { let request = self.build_request(rpc_name, args); let response = self.send_request(&request)?; - Ok(response.into_result()?) } @@ -113,8 +158,10 @@ impl Client { // Errors only on invalid header or builder reuse. let http_request = request_builder.body(&request_raw[..]).unwrap(); - let http_response = - self.roundtripper.request(http_request).map_err(|e| Error::Http(Box::new(e)))?; + let http_response = self + .roundtripper + .request(http_request) + .map_err(|e| Error::Http(Box::new(e)))?; // nb we ignore stream.status since we expect the body // to contain information about any error @@ -122,8 +169,9 @@ impl Client { } /// Sends a request to a client - pub fn send_request(&self, request: &Request) -> Result { + pub fn send_request<'a, 'b>(&self, request: &Request<'a, 'b>) -> Result { let response: Response = self.send_raw(&request)?; + if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) { return Err(Error::VersionMismatch); } @@ -138,7 +186,10 @@ impl Client { /// /// Note that the requests need to have valid IDs, so it is advised to create the requests /// with [build_request]. - pub fn send_batch(&self, requests: &[Request]) -> Result>, Error> { + pub fn send_batch<'a, 'b>( + &self, + requests: &[Request<'a, 'b>], + ) -> Result>, Error> { if requests.len() < 1 { return Err(Error::EmptyBatch); } @@ -146,6 +197,7 @@ impl Client { // If the request body is invalid JSON, the response is a single response object. // We ignore this case since we are confident we are producing valid JSON. let responses: Vec = self.send_raw(&requests)?; + if responses.len() > requests.len() { return Err(Error::WrongBatchResponseSize); } @@ -162,8 +214,10 @@ impl Client { } } // Match responses to the requests. - let results = - requests.into_iter().map(|r| resp_by_id.remove(&HashableValue(&r.id))).collect(); + let results = requests + .into_iter() + .map(|r| resp_by_id.remove(&HashableValue(&r.id))) + .collect(); // Since we're also just producing the first duplicate ID, we can also just produce the // first incorrect ID in case there are multiple. @@ -173,26 +227,120 @@ impl Client { Ok(results) } +} - /// Builds a request - pub fn build_request<'a, 'b>( +#[cfg(feature = "async")] +impl Client { + /// Make a request and deserialize the response + pub async fn do_rpc_async serde::de::Deserialize<'a>>( &self, - name: &'a str, - params: &'b [serde_json::Value], - ) -> Request<'a, 'b> { - let mut nonce = self.nonce.lock().unwrap(); - *nonce += 1; - Request { - method: name, - params: params, - id: From::from(*nonce), - jsonrpc: Some("2.0"), + rpc_name: &str, + args: &[serde_json::value::Value], + ) -> Result { + let request = self.build_request(rpc_name, args); + let response = self.send_request_async(&request).await?; + Ok(response.into_result()?) + } + + /// The actual send logic used by both [send_request] and [send_batch]. + async fn send_raw_async(&self, body: &B) -> Result + where + B: serde::ser::Serialize, + R: for<'de> serde::de::Deserialize<'de>, + { + // Build request + let request_raw = serde_json::to_vec(body)?; + + // Send request + let mut request_builder = http::Request::post(&self.url); + request_builder.header("Content-Type", "application/json-rpc"); + + // Set Authorization header + if let Some(ref user) = self.user { + let mut auth = user.clone(); + auth.push(':'); + if let Some(ref pass) = self.pass { + auth.push_str(&pass[..]); + } + let value = format!("Basic {}", &base64::encode(auth.as_bytes())); + request_builder.header("Authorization", value); + } + + // Errors only on invalid header or builder reuse. + let http_request = request_builder.body(&request_raw[..]).unwrap(); + + let http_response = self + .roundtripper + .request(http_request) + .await + .map_err(|e| Error::Http(Box::new(e)))?; + + // nb we ignore stream.status since we expect the body + // to contain information about any error + Ok(serde_json::from_reader(http_response.into_body())?) + } + + /// Sends a request to a client + pub async fn send_request_async<'a, 'b>( + &self, + request: &Request<'a, 'b>, + ) -> Result { + let response: Response = self.send_raw_async(&request).await?; + + if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) { + return Err(Error::VersionMismatch); + } + if response.id != request.id { + return Err(Error::NonceMismatch); } + Ok(response) } - /// Accessor for the last-used nonce - pub fn last_nonce(&self) -> u64 { - *self.nonce.lock().unwrap() + /// Sends a batch of requests to the client. The return vector holds the response + /// for the request at the corresponding index. If no response was provided, it's [None]. + /// + /// Note that the requests need to have valid IDs, so it is advised to create the requests + /// with [build_request]. + pub async fn send_batch_async<'a, 'b>( + &self, + requests: &[Request<'a, 'b>], + ) -> Result>, Error> { + if requests.len() < 1 { + return Err(Error::EmptyBatch); + } + + // If the request body is invalid JSON, the response is a single response object. + // We ignore this case since we are confident we are producing valid JSON. + let responses: Vec = self.send_raw_async(&requests).await?; + + if responses.len() > requests.len() { + return Err(Error::WrongBatchResponseSize); + } + + // To prevent having to clone responses, we first copy all the IDs so we can reference + // them easily. IDs can only be of JSON type String or Number (or Null), so cloning + // should be inexpensive and require no allocations as Numbers are more common. + let ids: Vec = responses.iter().map(|r| r.id.clone()).collect(); + // First index responses by ID and catch duplicate IDs. + let mut resp_by_id = HashMap::new(); + for (id, resp) in ids.iter().zip(responses.into_iter()) { + if let Some(dup) = resp_by_id.insert(HashableValue(&id), resp) { + return Err(Error::BatchDuplicateResponseId(dup.id)); + } + } + // Match responses to the requests. + let results = requests + .into_iter() + .map(|r| resp_by_id.remove(&HashableValue(&r.id))) + .collect(); + + // Since we're also just producing the first duplicate ID, we can also just produce the + // first incorrect ID in case there are multiple. + if let Some(incorrect) = resp_by_id.into_iter().nth(0) { + return Err(Error::WrongBatchResponseId(incorrect.1.id)); + } + + Ok(results) } } @@ -206,12 +354,31 @@ mod tests { type ResponseBody = io::Empty; type Err = io::Error; + #[cfg(not(feature = "async"))] fn request( &self, _: http::Request<&[u8]>, ) -> Result, Self::Err> { Err(io::ErrorKind::Other.into()) } + + #[cfg(feature = "async")] + fn request<'life>( + &'life self, + request: http::Request<&[u8]>, + ) -> std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result, Self::Err>, + > + Send + + 'life, + >, + > + where + Self: Sync + 'life, + { + Box::pin(async { Err(io::ErrorKind::Other.into()) }) + } } #[test] diff --git a/src/error.rs b/src/error.rs index 5ff19119..8448c6c9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -21,7 +21,7 @@ use std::{error, fmt}; use serde_json; -use Response; +use crate::Response; /// A library error #[derive(Debug)] diff --git a/src/util.rs b/src/util.rs index 482a069e..9f2aceb8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -44,18 +44,18 @@ impl<'a> Hash for HashableValue<'a> { } else { n.to_string().hash(state); } - }, + } Value::String(ref s) => { "string".hash(state); s.hash(state); - }, + } Value::Array(ref v) => { "array".hash(state); v.len().hash(state); for obj in v { HashableValue(obj).hash(state); } - }, + } Value::Object(ref m) => { "object".hash(state); m.len().hash(state); @@ -116,5 +116,3 @@ mod tests { assert!(coll.contains(&m)); } } - -