Skip to content
This repository has been archived by the owner on Sep 4, 2024. It is now read-only.

Commit

Permalink
conditional compilation for async roundtripper
Browse files Browse the repository at this point in the history
Signed-off-by: Gregory Hill <[email protected]>
  • Loading branch information
gregdhill committed Oct 13, 2020
1 parent e651798 commit b20149d
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 38 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ documentation = "https://docs.rs/jsonrpc/"
description = "Rust support for the JSON-RPC 2.0 protocol"
keywords = [ "protocol", "json", "http", "jsonrpc" ]
readme = "README.md"
edition = "2018"

[features]
async = []

[lib]
name = "jsonrpc"
Expand Down
231 changes: 199 additions & 32 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
//! and parsing responses
//!
use std::{error, io};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::{error, io};

use serde;
use base64;
use http;
use serde;
use serde_json;

use super::{Request, Response};
use util::HashableValue;
use error::Error;
use crate::error::Error;
use crate::util::HashableValue;

/// An interface for an HTTP roundtripper that handles HTTP requests.
pub trait HttpRoundTripper {
Expand All @@ -38,30 +38,53 @@ pub trait HttpRoundTripper {
/// The type for errors generated by the roundtripper.
type Err: error::Error;

/// Make an HTTP request. In practice only POST request will be made.
/// Make a synchronous HTTP request. In practice only POST request will be made.
fn request(
&self,
http::Request<&[u8]>,
_request: http::Request<&[u8]>,
) -> Result<http::Response<Self::ResponseBody>, Self::Err>;
}

/// An interface for an asynchronous HTTP roundtripper that handles HTTP requests.
#[cfg(feature = "async")]
pub trait AsyncHttpRoundTripper {
/// The type of the http::Response body.
type ResponseBody: io::Read;
/// The type for errors generated by the roundtripper.
type Err: error::Error;

/// Make an asynchronous HTTP request. In practice only POST request will be made.
fn request<'life>(
&'life self,
_request: http::Request<&'life [u8]>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<http::Response<Self::ResponseBody>, Self::Err>>
+ Send
+ 'life,
>,
>
where
Self: Sync + 'life;
}

/// A handle to a remote JSONRPC server
pub struct Client<R: HttpRoundTripper> {
pub struct Client<R> {
url: String,
user: Option<String>,
pass: Option<String>,
roundtripper: R,
nonce: Arc<Mutex<u64>>,
}

impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
impl<R> Client<R> {
/// Creates a new client
pub fn new(
roundtripper: Rt,
roundtripper: R,
url: String,
user: Option<String>,
pass: Option<String>,
) -> Client<Rt> {
) -> Client<R> {
// Check that if we have a password, we have a username; other way around is ok
debug_assert!(pass.is_none() || user.is_some());

Expand All @@ -74,6 +97,29 @@ impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
}
}

/// Builds a request
pub fn build_request<'a, 'b>(
&self,
name: &'a str,
params: &'b [serde_json::Value],
) -> Request<'a, 'b> {
let mut nonce = self.nonce.lock().unwrap();
*nonce += 1;
Request {
method: name,
params: params,
id: From::from(*nonce),
jsonrpc: Some("2.0"),
}
}

/// Accessor for the last-used nonce
pub fn last_nonce(&self) -> u64 {
*self.nonce.lock().unwrap()
}
}

impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
/// Make a request and deserialize the response
pub fn do_rpc<T: for<'a> serde::de::Deserialize<'a>>(
&self,
Expand All @@ -82,7 +128,6 @@ impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
) -> Result<T, Error> {
let request = self.build_request(rpc_name, args);
let response = self.send_request(&request)?;

Ok(response.into_result()?)
}

Expand Down Expand Up @@ -113,17 +158,20 @@ impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
// Errors only on invalid header or builder reuse.
let http_request = request_builder.body(&request_raw[..]).unwrap();

let http_response =
self.roundtripper.request(http_request).map_err(|e| Error::Http(Box::new(e)))?;
let http_response = self
.roundtripper
.request(http_request)
.map_err(|e| Error::Http(Box::new(e)))?;

// nb we ignore stream.status since we expect the body
// to contain information about any error
Ok(serde_json::from_reader(http_response.into_body())?)
}

/// Sends a request to a client
pub fn send_request(&self, request: &Request) -> Result<Response, Error> {
pub fn send_request<'a, 'b>(&self, request: &Request<'a, 'b>) -> Result<Response, Error> {
let response: Response = self.send_raw(&request)?;

if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) {
return Err(Error::VersionMismatch);
}
Expand All @@ -138,14 +186,18 @@ impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
///
/// Note that the requests need to have valid IDs, so it is advised to create the requests
/// with [build_request].
pub fn send_batch(&self, requests: &[Request]) -> Result<Vec<Option<Response>>, Error> {
pub fn send_batch<'a, 'b>(
&self,
requests: &[Request<'a, 'b>],
) -> Result<Vec<Option<Response>>, Error> {
if requests.len() < 1 {
return Err(Error::EmptyBatch);
}

// If the request body is invalid JSON, the response is a single response object.
// We ignore this case since we are confident we are producing valid JSON.
let responses: Vec<Response> = self.send_raw(&requests)?;

if responses.len() > requests.len() {
return Err(Error::WrongBatchResponseSize);
}
Expand All @@ -162,8 +214,10 @@ impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
}
}
// Match responses to the requests.
let results =
requests.into_iter().map(|r| resp_by_id.remove(&HashableValue(&r.id))).collect();
let results = requests
.into_iter()
.map(|r| resp_by_id.remove(&HashableValue(&r.id)))
.collect();

// Since we're also just producing the first duplicate ID, we can also just produce the
// first incorrect ID in case there are multiple.
Expand All @@ -173,26 +227,120 @@ impl<Rt: HttpRoundTripper + 'static> Client<Rt> {

Ok(results)
}
}

/// Builds a request
pub fn build_request<'a, 'b>(
#[cfg(feature = "async")]
impl<Rt: AsyncHttpRoundTripper + 'static + Sync> Client<Rt> {
/// Make a request and deserialize the response
pub async fn do_rpc_async<T: for<'a> serde::de::Deserialize<'a>>(
&self,
name: &'a str,
params: &'b [serde_json::Value],
) -> Request<'a, 'b> {
let mut nonce = self.nonce.lock().unwrap();
*nonce += 1;
Request {
method: name,
params: params,
id: From::from(*nonce),
jsonrpc: Some("2.0"),
rpc_name: &str,
args: &[serde_json::value::Value],
) -> Result<T, Error> {
let request = self.build_request(rpc_name, args);
let response = self.send_request_async(&request).await?;
Ok(response.into_result()?)
}

/// The actual send logic used by both [send_request] and [send_batch].
async fn send_raw_async<B, R>(&self, body: &B) -> Result<R, Error>
where
B: serde::ser::Serialize,
R: for<'de> serde::de::Deserialize<'de>,
{
// Build request
let request_raw = serde_json::to_vec(body)?;

// Send request
let mut request_builder = http::Request::post(&self.url);
request_builder.header("Content-Type", "application/json-rpc");

// Set Authorization header
if let Some(ref user) = self.user {
let mut auth = user.clone();
auth.push(':');
if let Some(ref pass) = self.pass {
auth.push_str(&pass[..]);
}
let value = format!("Basic {}", &base64::encode(auth.as_bytes()));
request_builder.header("Authorization", value);
}

// Errors only on invalid header or builder reuse.
let http_request = request_builder.body(&request_raw[..]).unwrap();

let http_response = self
.roundtripper
.request(http_request)
.await
.map_err(|e| Error::Http(Box::new(e)))?;

// nb we ignore stream.status since we expect the body
// to contain information about any error
Ok(serde_json::from_reader(http_response.into_body())?)
}

/// Sends a request to a client
pub async fn send_request_async<'a, 'b>(
&self,
request: &Request<'a, 'b>,
) -> Result<Response, Error> {
let response: Response = self.send_raw_async(&request).await?;

if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) {
return Err(Error::VersionMismatch);
}
if response.id != request.id {
return Err(Error::NonceMismatch);
}
Ok(response)
}

/// Accessor for the last-used nonce
pub fn last_nonce(&self) -> u64 {
*self.nonce.lock().unwrap()
/// Sends a batch of requests to the client. The return vector holds the response
/// for the request at the corresponding index. If no response was provided, it's [None].
///
/// Note that the requests need to have valid IDs, so it is advised to create the requests
/// with [build_request].
pub async fn send_batch_async<'a, 'b>(
&self,
requests: &[Request<'a, 'b>],
) -> Result<Vec<Option<Response>>, Error> {
if requests.len() < 1 {
return Err(Error::EmptyBatch);
}

// If the request body is invalid JSON, the response is a single response object.
// We ignore this case since we are confident we are producing valid JSON.
let responses: Vec<Response> = self.send_raw_async(&requests).await?;

if responses.len() > requests.len() {
return Err(Error::WrongBatchResponseSize);
}

// To prevent having to clone responses, we first copy all the IDs so we can reference
// them easily. IDs can only be of JSON type String or Number (or Null), so cloning
// should be inexpensive and require no allocations as Numbers are more common.
let ids: Vec<serde_json::Value> = responses.iter().map(|r| r.id.clone()).collect();
// First index responses by ID and catch duplicate IDs.
let mut resp_by_id = HashMap::new();
for (id, resp) in ids.iter().zip(responses.into_iter()) {
if let Some(dup) = resp_by_id.insert(HashableValue(&id), resp) {
return Err(Error::BatchDuplicateResponseId(dup.id));
}
}
// Match responses to the requests.
let results = requests
.into_iter()
.map(|r| resp_by_id.remove(&HashableValue(&r.id)))
.collect();

// Since we're also just producing the first duplicate ID, we can also just produce the
// first incorrect ID in case there are multiple.
if let Some(incorrect) = resp_by_id.into_iter().nth(0) {
return Err(Error::WrongBatchResponseId(incorrect.1.id));
}

Ok(results)
}
}

Expand All @@ -206,12 +354,31 @@ mod tests {
type ResponseBody = io::Empty;
type Err = io::Error;

#[cfg(not(feature = "async"))]
fn request(
&self,
_: http::Request<&[u8]>,
) -> Result<http::Response<Self::ResponseBody>, Self::Err> {
Err(io::ErrorKind::Other.into())
}

#[cfg(feature = "async")]
fn request<'life>(
&'life self,
request: http::Request<&[u8]>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<
Output = Result<http::Response<Self::ResponseBody>, Self::Err>,
> + Send
+ 'life,
>,
>
where
Self: Sync + 'life,
{
Box::pin(async { Err(io::ErrorKind::Other.into()) })
}
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::{error, fmt};

use serde_json;

use Response;
use crate::Response;

/// A library error
#[derive(Debug)]
Expand Down
8 changes: 3 additions & 5 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ impl<'a> Hash for HashableValue<'a> {
} else {
n.to_string().hash(state);
}
},
}
Value::String(ref s) => {
"string".hash(state);
s.hash(state);
},
}
Value::Array(ref v) => {
"array".hash(state);
v.len().hash(state);
for obj in v {
HashableValue(obj).hash(state);
}
},
}
Value::Object(ref m) => {
"object".hash(state);
m.len().hash(state);
Expand Down Expand Up @@ -116,5 +116,3 @@ mod tests {
assert!(coll.contains(&m));
}
}


0 comments on commit b20149d

Please sign in to comment.