diff --git a/src/electrum.rs b/src/electrum.rs index d815d63a3..fed999454 100644 --- a/src/electrum.rs +++ b/src/electrum.rs @@ -10,7 +10,9 @@ use serde_derive::Deserialize; use serde_json::{self, json, Value}; use std::collections::{hash_map::Entry, HashMap}; +use std::fmt; use std::iter::FromIterator; +use std::str::FromStr; use crate::{ cache::Cache, @@ -54,9 +56,9 @@ enum Requests { #[derive(Deserialize, Debug, PartialEq, Eq)] #[serde(untagged)] -enum Version { +enum VersionRequest { Single(String), - Range(String, String), + MinMax(String, String), } #[derive(Deserialize)] @@ -437,20 +439,13 @@ impl Rpc { format!("electrs/{}", ELECTRS_VERSION) } - fn version(&self, (client_id, client_version): &(String, Version)) -> Result { + fn version(&self, (client_id, client_version): &(String, VersionRequest)) -> Result { match client_version { - Version::Single(v) if v == PROTOCOL_VERSION => { - Ok(json!([self.server_id(), PROTOCOL_VERSION])) - } - _ => { - bail!( - "{} requested {:?}, server supports {}", - client_id, - client_version, - PROTOCOL_VERSION - ); - } + VersionRequest::Single(exact) => check_between(PROTOCOL_VERSION, exact, exact), + VersionRequest::MinMax(min, max) => check_between(PROTOCOL_VERSION, min, max), } + .with_context(|| format!("unsupported request {:?} by {}", client_version, client_id))?; + Ok(json!([self.server_id(), PROTOCOL_VERSION])) } fn features(&self) -> Result { @@ -594,7 +589,7 @@ enum Params { TransactionGet(TxGetArgs), TransactionGetMerkle((Txid, usize)), TransactionFromPosition((usize, usize, bool)), - Version((String, Version)), + Version((String, VersionRequest)), } impl Params { @@ -726,3 +721,66 @@ fn parse_requests(line: &str) -> Result { } } } + +fn parse_version(version: &str) -> Result { + let result = version + .split('.') + .map(|part| usize::from_str(part).with_context(|| format!("invalid version {}", version))) + .collect::>>()?; + Ok(Version(result)) +} + +#[derive(PartialOrd, PartialEq, Debug)] +struct Version(Vec); + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for (i, v) in self.0.iter().enumerate() { + if i > 0 { + write!(f, ".")?; + } + write!(f, "{}", v)?; + } + Ok(()) + } +} + +fn check_between(version_str: &str, min_str: &str, max_str: &str) -> Result<()> { + let version = parse_version(version_str)?; + let min = parse_version(min_str)?; + if version < min { + bail!("version {} < {}", version, min); + } + let max = parse_version(max_str)?; + if version > max { + bail!("version {} > {}", version, max); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::{check_between, parse_version, Version}; + + #[test] + fn test_version() { + assert_eq!(parse_version("1").unwrap(), Version(vec![1])); + assert_eq!(parse_version("1.2").unwrap(), Version(vec![1, 2])); + assert_eq!(parse_version("1.2.345").unwrap(), Version(vec![1, 2, 345])); + + assert!(parse_version("1.2").unwrap() < parse_version("1.100").unwrap()); + } + + #[test] + fn test_between() { + assert!(check_between("1.4", "1.4", "1.4").is_ok()); + assert!(check_between("1.4", "1.4", "1.5").is_ok()); + assert!(check_between("1.4", "1.3", "1.4").is_ok()); + assert!(check_between("1.4", "1.3", "1.5").is_ok()); + + assert!(check_between("1.4", "1.5", "1.5").is_err()); + assert!(check_between("1.4", "1.3", "1.3").is_err()); + assert!(check_between("1.4", "1.4.1", "1.5").is_err()); + assert!(check_between("1.4", "1", "1").is_err()); + } +}