diff --git a/Cargo.lock b/Cargo.lock index 377f70c..25e5fde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -344,6 +344,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +[[package]] +name = "futures-io" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" + [[package]] name = "futures-sink" version = "0.3.28" @@ -363,9 +369,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-core", + "futures-io", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1303,21 +1312,9 @@ dependencies = [ "num_cpus", "pin-project-lite", "socket2 0.5.3", - "tokio-macros", "windows-sys", ] -[[package]] -name = "tokio-macros" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.31", -] - [[package]] name = "tokio-native-tls" version = "0.3.1" @@ -1396,7 +1393,6 @@ dependencies = [ "reqwest", "sha2", "tempfile", - "tokio", "update-format-crau", "url", "uuid", diff --git a/Cargo.toml b/Cargo.toml index 8375a8a..5dda2de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,9 @@ env_logger = "0.10" globset = "0.4" log = "0.4" protobuf = "3.2.0" -reqwest = "0.11" +reqwest = { version = "0.11", features = ["blocking"] } sha2 = "0.10" tempfile = "3.8.1" -tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } url = "2" uuid = "1.2" diff --git a/examples/download_test.rs b/examples/download_test.rs index bdab4f7..ba2f34b 100644 --- a/examples/download_test.rs +++ b/examples/download_test.rs @@ -4,16 +4,17 @@ use std::str::FromStr; use ue_rs::download_and_hash; -#[tokio::main] -async fn main() -> Result<(), Box> { - let client = reqwest::Client::new(); +fn main() -> Result<(), Box> { + let client = reqwest::blocking::Client::new(); let url = Url::from_str(std::env::args().nth(1).expect("missing URL (second argument)").as_str())?; println!("fetching {}...", url); - let data = Vec::new(); - let res = download_and_hash(&client, url, data, false).await?; + let tempdir = tempfile::tempdir()?; + let path = tempdir.path().join("tmpfile"); + let res = download_and_hash(&client, url, &path, false)?; + tempdir.close()?; println!("hash: {}", res.hash); diff --git a/examples/full_test.rs b/examples/full_test.rs index 4a614b2..2992a6a 100644 --- a/examples/full_test.rs +++ b/examples/full_test.rs @@ -41,9 +41,8 @@ fn get_pkgs_to_download(resp: &omaha::Response) -> Result Result<(), Box> { - let client = reqwest::Client::new(); +fn main() -> Result<(), Box> { + let client = reqwest::blocking::Client::new(); const APP_VERSION_DEFAULT: &str = "3340.0.0+nightly-20220823-2100"; const MACHINE_ID_DEFAULT: &str = "abce671d61774703ac7be60715220bfe"; @@ -59,7 +58,7 @@ async fn main() -> Result<(), Box> { track: Cow::Borrowed(TRACK_DEFAULT), }; - let response_text = ue_rs::request::perform(&client, parameters).await.context(format!( + let response_text = ue_rs::request::perform(&client, parameters).context(format!( "perform({APP_VERSION_DEFAULT}, {MACHINE_ID_DEFAULT}, {TRACK_DEFAULT}) failed" ))?; @@ -79,11 +78,10 @@ async fn main() -> Result<(), Box> { for (url, expected_sha256) in pkgs_to_dl { println!("downloading {}...", url); - // TODO: use a file or anything that implements std::io::Write here. - // std::io::BufWriter wrapping an std::fs::File is probably the right choice. - // std::io::sink() is basically just /dev/null - let data = std::io::sink(); - let res = ue_rs::download_and_hash(&client, url.clone(), data, false).await.context(format!("download_and_hash({url:?}) failed"))?; + let tempdir = tempfile::tempdir()?; + let path = tempdir.path().join("tmpfile"); + let res = ue_rs::download_and_hash(&client, url.clone(), &path, false).context(format!("download_and_hash({url:?}) failed"))?; + tempdir.close()?; println!("\texpected sha256: {}", expected_sha256); println!("\tcalculated sha256: {}", res.hash); diff --git a/examples/request.rs b/examples/request.rs index 9468e46..d3b41f8 100644 --- a/examples/request.rs +++ b/examples/request.rs @@ -5,9 +5,8 @@ use anyhow::Context; use ue_rs::request; -#[tokio::main] -async fn main() -> Result<(), Box> { - let client = reqwest::Client::new(); +fn main() -> Result<(), Box> { + let client = reqwest::blocking::Client::new(); const APP_VERSION_DEFAULT: &str = "3340.0.0+nightly-20220823-2100"; const MACHINE_ID_DEFAULT: &str = "abce671d61774703ac7be60715220bfe"; @@ -20,7 +19,7 @@ async fn main() -> Result<(), Box> { track: Cow::Borrowed(TRACK_DEFAULT), }; - let response = request::perform(&client, parameters).await.context(format!( + let response = request::perform(&client, parameters).context(format!( "perform({APP_VERSION_DEFAULT}, {MACHINE_ID_DEFAULT}, {TRACK_DEFAULT}) failed" ))?; diff --git a/src/bin/download_sysext.rs b/src/bin/download_sysext.rs index 5eadb8b..1f9ab6b 100644 --- a/src/bin/download_sysext.rs +++ b/src/bin/download_sysext.rs @@ -16,7 +16,7 @@ use argh::FromArgs; use globset::{Glob, GlobSet, GlobSetBuilder}; use hard_xml::XmlRead; use omaha::FileSize; -use reqwest::Client; +use reqwest::blocking::Client; use reqwest::redirect::Policy; use url::Url; @@ -94,7 +94,7 @@ impl<'a> Package<'a> { Ok(()) } - async fn download(&mut self, into_dir: &Path, client: &reqwest::Client, print_progress: bool) -> Result<()> { + fn download(&mut self, into_dir: &Path, client: &Client, print_progress: bool) -> Result<()> { // FIXME: use _range_start for completing downloads let _range_start = match self.status { PackageStatus::ToDownload => 0, @@ -105,9 +105,7 @@ impl<'a> Package<'a> { info!("downloading {}...", self.url); let path = into_dir.join(&*self.name); - let mut file = File::create(path.clone()).context(format!("failed to create path ({:?})", path.display()))?; - - let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file, print_progress).await { + let res = match ue_rs::download_and_hash(client, self.url.clone(), &path, print_progress) { Ok(ok) => ok, Err(err) => { error!("Downloading failed with error {}", err); @@ -243,28 +241,26 @@ fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet) } // Read data from remote URL into File -async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result> +fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result> where U: reqwest::IntoUrl + From + std::clone::Clone + std::fmt::Debug, Url: From, { - let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?; - - ue_rs::download_and_hash(client, input_url.clone(), &mut file, print_progress).await.context(format!("unable to download data(url {:?})", input_url))?; + let r = ue_rs::download_and_hash(client, input_url.clone(), path, print_progress).context(format!("unable to download data(url {:?})", input_url))?; Ok(Package { name: Cow::Borrowed(path.file_name().unwrap_or(OsStr::new("fakepackage")).to_str().unwrap_or("fakepackage")), - hash: hash_on_disk_sha256(path, None)?, - size: FileSize::from_bytes(file.metadata().context(format!("failed to get metadata, path ({:?})", path.display()))?.len() as usize), + hash: r.hash, + size: FileSize::from_bytes(r.data.metadata().context(format!("failed to get metadata, path ({:?})", path.display()))?.len() as usize), url: input_url.into(), status: PackageStatus::Unverified, }) } -async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> { +fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> { pkg.check_download(unverified_dir)?; - pkg.download(unverified_dir, client, print_progress).await.context(format!("unable to download \"{:?}\"", pkg.name))?; + pkg.download(unverified_dir, client, print_progress).context(format!("unable to download \"{:?}\"", pkg.name))?; // Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz". // Verified payload is stored in e.g. "output_dir/oem.raw". @@ -322,8 +318,7 @@ impl Args { } } -#[tokio::main] -async fn main() -> Result<(), Box> { +fn main() -> Result<(), Box> { env_logger::init(); let args: Args = argh::from_env(); @@ -374,8 +369,7 @@ async fn main() -> Result<(), Box> { Url::from_str(url.as_str()).context(anyhow!("failed to convert into url ({:?})", url))?, &client, args.print_progress, - ) - .await?; + )?; do_download_verify( &mut pkg_fake, output_dir, @@ -383,8 +377,7 @@ async fn main() -> Result<(), Box> { args.pubkey_file.as_str(), &client, args.print_progress, - ) - .await?; + )?; // verify only a fake package, early exit and skip the rest. return Ok(()); @@ -417,8 +410,7 @@ async fn main() -> Result<(), Box> { args.pubkey_file.as_str(), &client, args.print_progress, - ) - .await?; + )?; } // clean up data diff --git a/src/download.rs b/src/download.rs index d00b841..15f4324 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,18 +1,18 @@ use anyhow::{Context, Result, bail}; -use std::io::{BufReader, Read, Seek, SeekFrom, Write}; -use std::io; +use std::io::{BufReader, Read, Seek, SeekFrom}; use std::fs::File; use std::path::Path; use log::info; use url::Url; use reqwest::StatusCode; +use reqwest::blocking::Client; use sha2::{Sha256, Digest}; -pub struct DownloadResult { +pub struct DownloadResult { pub hash: omaha::Hash, - pub data: W, + pub data: File, } pub fn hash_on_disk_sha256(path: &Path, maxlen: Option) -> Result> { @@ -57,10 +57,9 @@ pub fn hash_on_disk_sha256(path: &Path, maxlen: Option) -> Result(client: &reqwest::Client, url: U, mut data: W, print_progress: bool) -> Result> +pub fn download_and_hash(client: &Client, url: U, path: &Path, print_progress: bool) -> Result where U: reqwest::IntoUrl + Clone, - W: io::Write, Url: From, { let client_url = url.clone(); @@ -68,7 +67,6 @@ where #[rustfmt::skip] let mut res = client.get(url) .send() - .await .context(format!("client get and send({:?}) failed", client_url.as_str()))?; // Redirect was already handled at this point, so there is no need to touch @@ -89,33 +87,14 @@ where } } - let mut hasher = Sha256::new(); - - let mut bytes_read = 0usize; - let bytes_to_read = res.content_length().unwrap_or(u64::MAX) as usize; - - while let Some(chunk) = res.chunk().await.context("failed to get response chunk")? { - bytes_read += chunk.len(); - - hasher.update(&chunk); - data.write_all(&chunk).context("failed to write_all chunk")?; - - if print_progress { - print!( - "\rread {}/{} ({:3}%)", - bytes_read, - bytes_to_read, - ((bytes_read as f32 / bytes_to_read as f32) * 100.0f32).floor() - ); - io::stdout().flush().context("failed to flush stdout")?; - } + if print_progress { + println!("writing to {}", path.display()); } - - data.flush().context("failed to flush data")?; - println!(); + let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?; + res.copy_to(&mut file)?; Ok(DownloadResult { - hash: omaha::Hash::from_bytes(hasher.finalize().into()), - data, + hash: hash_on_disk_sha256(path, None)?, + data: file, }) } diff --git a/src/request.rs b/src/request.rs index c5dc5ee..4adcf40 100644 --- a/src/request.rs +++ b/src/request.rs @@ -28,7 +28,7 @@ pub struct Parameters<'a> { pub machine_id: Cow<'a, str>, } -pub async fn perform<'a>(client: &reqwest::Client, parameters: Parameters<'a>) -> Result { +pub fn perform<'a>(client: &reqwest::blocking::Client, parameters: Parameters<'a>) -> Result { let req_body = { let r = omaha::Request { protocol_version: Cow::Borrowed(PROTOCOL_VERSION), @@ -78,8 +78,7 @@ pub async fn perform<'a>(client: &reqwest::Client, parameters: Parameters<'a>) - let resp = client.post(UPDATE_URL) .body(req_body) .send() - .await .context("client post send({UPDATE_URL}) failed")?; - resp.text().await.context("failed to get response") + resp.text().context("failed to get response") }