From 8244bc4794639d83640fe56dab1f396d66878769 Mon Sep 17 00:00:00 2001 From: Kai Lueke Date: Thu, 21 Dec 2023 15:40:46 +0100 Subject: [PATCH] Switch to reqwest::blocking and use a File object The use of async for the client was motivated by doing hashing while downloading and ideally everything else too so it would extract to disk. However, only the hashing was done and all other operations including and additional hash step were done on the file. Since the use of async brought more problems while not having a real benefit, switch to a blocking mode. Instead of opening the file and passing a write to the download function we now open the file inside the download function which will make it easier to retry downloads. --- Cargo.lock | 22 ++++++++----------- Cargo.toml | 3 +-- examples/download_test.rs | 11 +++++----- examples/full_test.rs | 16 +++++++------- examples/request.rs | 7 +++---- src/bin/download_sysext.rs | 34 ++++++++++++------------------ src/download.rs | 43 ++++++++++---------------------------- src/request.rs | 5 ++--- 8 files changed, 52 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 377f70c..25e5fde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -344,6 +344,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +[[package]] +name = "futures-io" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" + [[package]] name = "futures-sink" version = "0.3.28" @@ -363,9 +369,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-core", + "futures-io", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1303,21 +1312,9 @@ dependencies = [ "num_cpus", "pin-project-lite", "socket2 0.5.3", - "tokio-macros", "windows-sys", ] -[[package]] -name = "tokio-macros" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.31", -] - [[package]] name = "tokio-native-tls" version = "0.3.1" @@ -1396,7 +1393,6 @@ dependencies = [ "reqwest", "sha2", "tempfile", - "tokio", "update-format-crau", "url", "uuid", diff --git a/Cargo.toml b/Cargo.toml index 8375a8a..5dda2de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,9 @@ env_logger = "0.10" globset = "0.4" log = "0.4" protobuf = "3.2.0" -reqwest = "0.11" +reqwest = { version = "0.11", features = ["blocking"] } sha2 = "0.10" tempfile = "3.8.1" -tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } url = "2" uuid = "1.2" diff --git a/examples/download_test.rs b/examples/download_test.rs index bdab4f7..ba2f34b 100644 --- a/examples/download_test.rs +++ b/examples/download_test.rs @@ -4,16 +4,17 @@ use std::str::FromStr; use ue_rs::download_and_hash; -#[tokio::main] -async fn main() -> Result<(), Box> { - let client = reqwest::Client::new(); +fn main() -> Result<(), Box> { + let client = reqwest::blocking::Client::new(); let url = Url::from_str(std::env::args().nth(1).expect("missing URL (second argument)").as_str())?; println!("fetching {}...", url); - let data = Vec::new(); - let res = download_and_hash(&client, url, data, false).await?; + let tempdir = tempfile::tempdir()?; + let path = tempdir.path().join("tmpfile"); + let res = download_and_hash(&client, url, &path, false)?; + tempdir.close()?; println!("hash: {}", res.hash); diff --git a/examples/full_test.rs b/examples/full_test.rs index 4a614b2..2992a6a 100644 --- a/examples/full_test.rs +++ b/examples/full_test.rs @@ -41,9 +41,8 @@ fn get_pkgs_to_download(resp: &omaha::Response) -> Result Result<(), Box> { - let client = reqwest::Client::new(); +fn main() -> Result<(), Box> { + let client = reqwest::blocking::Client::new(); const APP_VERSION_DEFAULT: &str = "3340.0.0+nightly-20220823-2100"; const MACHINE_ID_DEFAULT: &str = "abce671d61774703ac7be60715220bfe"; @@ -59,7 +58,7 @@ async fn main() -> Result<(), Box> { track: Cow::Borrowed(TRACK_DEFAULT), }; - let response_text = ue_rs::request::perform(&client, parameters).await.context(format!( + let response_text = ue_rs::request::perform(&client, parameters).context(format!( "perform({APP_VERSION_DEFAULT}, {MACHINE_ID_DEFAULT}, {TRACK_DEFAULT}) failed" ))?; @@ -79,11 +78,10 @@ async fn main() -> Result<(), Box> { for (url, expected_sha256) in pkgs_to_dl { println!("downloading {}...", url); - // TODO: use a file or anything that implements std::io::Write here. - // std::io::BufWriter wrapping an std::fs::File is probably the right choice. - // std::io::sink() is basically just /dev/null - let data = std::io::sink(); - let res = ue_rs::download_and_hash(&client, url.clone(), data, false).await.context(format!("download_and_hash({url:?}) failed"))?; + let tempdir = tempfile::tempdir()?; + let path = tempdir.path().join("tmpfile"); + let res = ue_rs::download_and_hash(&client, url.clone(), &path, false).context(format!("download_and_hash({url:?}) failed"))?; + tempdir.close()?; println!("\texpected sha256: {}", expected_sha256); println!("\tcalculated sha256: {}", res.hash); diff --git a/examples/request.rs b/examples/request.rs index 9468e46..d3b41f8 100644 --- a/examples/request.rs +++ b/examples/request.rs @@ -5,9 +5,8 @@ use anyhow::Context; use ue_rs::request; -#[tokio::main] -async fn main() -> Result<(), Box> { - let client = reqwest::Client::new(); +fn main() -> Result<(), Box> { + let client = reqwest::blocking::Client::new(); const APP_VERSION_DEFAULT: &str = "3340.0.0+nightly-20220823-2100"; const MACHINE_ID_DEFAULT: &str = "abce671d61774703ac7be60715220bfe"; @@ -20,7 +19,7 @@ async fn main() -> Result<(), Box> { track: Cow::Borrowed(TRACK_DEFAULT), }; - let response = request::perform(&client, parameters).await.context(format!( + let response = request::perform(&client, parameters).context(format!( "perform({APP_VERSION_DEFAULT}, {MACHINE_ID_DEFAULT}, {TRACK_DEFAULT}) failed" ))?; diff --git a/src/bin/download_sysext.rs b/src/bin/download_sysext.rs index 5eadb8b..1f9ab6b 100644 --- a/src/bin/download_sysext.rs +++ b/src/bin/download_sysext.rs @@ -16,7 +16,7 @@ use argh::FromArgs; use globset::{Glob, GlobSet, GlobSetBuilder}; use hard_xml::XmlRead; use omaha::FileSize; -use reqwest::Client; +use reqwest::blocking::Client; use reqwest::redirect::Policy; use url::Url; @@ -94,7 +94,7 @@ impl<'a> Package<'a> { Ok(()) } - async fn download(&mut self, into_dir: &Path, client: &reqwest::Client, print_progress: bool) -> Result<()> { + fn download(&mut self, into_dir: &Path, client: &Client, print_progress: bool) -> Result<()> { // FIXME: use _range_start for completing downloads let _range_start = match self.status { PackageStatus::ToDownload => 0, @@ -105,9 +105,7 @@ impl<'a> Package<'a> { info!("downloading {}...", self.url); let path = into_dir.join(&*self.name); - let mut file = File::create(path.clone()).context(format!("failed to create path ({:?})", path.display()))?; - - let res = match ue_rs::download_and_hash(client, self.url.clone(), &mut file, print_progress).await { + let res = match ue_rs::download_and_hash(client, self.url.clone(), &path, print_progress) { Ok(ok) => ok, Err(err) => { error!("Downloading failed with error {}", err); @@ -243,28 +241,26 @@ fn get_pkgs_to_download<'a>(resp: &'a omaha::Response, glob_set: &GlobSet) } // Read data from remote URL into File -async fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result> +fn fetch_url_to_file<'a, U>(path: &'a Path, input_url: U, client: &'a Client, print_progress: bool) -> Result> where U: reqwest::IntoUrl + From + std::clone::Clone + std::fmt::Debug, Url: From, { - let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?; - - ue_rs::download_and_hash(client, input_url.clone(), &mut file, print_progress).await.context(format!("unable to download data(url {:?})", input_url))?; + let r = ue_rs::download_and_hash(client, input_url.clone(), path, print_progress).context(format!("unable to download data(url {:?})", input_url))?; Ok(Package { name: Cow::Borrowed(path.file_name().unwrap_or(OsStr::new("fakepackage")).to_str().unwrap_or("fakepackage")), - hash: hash_on_disk_sha256(path, None)?, - size: FileSize::from_bytes(file.metadata().context(format!("failed to get metadata, path ({:?})", path.display()))?.len() as usize), + hash: r.hash, + size: FileSize::from_bytes(r.data.metadata().context(format!("failed to get metadata, path ({:?})", path.display()))?.len() as usize), url: input_url.into(), status: PackageStatus::Unverified, }) } -async fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> { +fn do_download_verify(pkg: &mut Package<'_>, output_dir: &Path, unverified_dir: &Path, pubkey_file: &str, client: &Client, print_progress: bool) -> Result<()> { pkg.check_download(unverified_dir)?; - pkg.download(unverified_dir, client, print_progress).await.context(format!("unable to download \"{:?}\"", pkg.name))?; + pkg.download(unverified_dir, client, print_progress).context(format!("unable to download \"{:?}\"", pkg.name))?; // Unverified payload is stored in e.g. "output_dir/.unverified/oem.gz". // Verified payload is stored in e.g. "output_dir/oem.raw". @@ -322,8 +318,7 @@ impl Args { } } -#[tokio::main] -async fn main() -> Result<(), Box> { +fn main() -> Result<(), Box> { env_logger::init(); let args: Args = argh::from_env(); @@ -374,8 +369,7 @@ async fn main() -> Result<(), Box> { Url::from_str(url.as_str()).context(anyhow!("failed to convert into url ({:?})", url))?, &client, args.print_progress, - ) - .await?; + )?; do_download_verify( &mut pkg_fake, output_dir, @@ -383,8 +377,7 @@ async fn main() -> Result<(), Box> { args.pubkey_file.as_str(), &client, args.print_progress, - ) - .await?; + )?; // verify only a fake package, early exit and skip the rest. return Ok(()); @@ -417,8 +410,7 @@ async fn main() -> Result<(), Box> { args.pubkey_file.as_str(), &client, args.print_progress, - ) - .await?; + )?; } // clean up data diff --git a/src/download.rs b/src/download.rs index d00b841..15f4324 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,18 +1,18 @@ use anyhow::{Context, Result, bail}; -use std::io::{BufReader, Read, Seek, SeekFrom, Write}; -use std::io; +use std::io::{BufReader, Read, Seek, SeekFrom}; use std::fs::File; use std::path::Path; use log::info; use url::Url; use reqwest::StatusCode; +use reqwest::blocking::Client; use sha2::{Sha256, Digest}; -pub struct DownloadResult { +pub struct DownloadResult { pub hash: omaha::Hash, - pub data: W, + pub data: File, } pub fn hash_on_disk_sha256(path: &Path, maxlen: Option) -> Result> { @@ -57,10 +57,9 @@ pub fn hash_on_disk_sha256(path: &Path, maxlen: Option) -> Result(client: &reqwest::Client, url: U, mut data: W, print_progress: bool) -> Result> +pub fn download_and_hash(client: &Client, url: U, path: &Path, print_progress: bool) -> Result where U: reqwest::IntoUrl + Clone, - W: io::Write, Url: From, { let client_url = url.clone(); @@ -68,7 +67,6 @@ where #[rustfmt::skip] let mut res = client.get(url) .send() - .await .context(format!("client get and send({:?}) failed", client_url.as_str()))?; // Redirect was already handled at this point, so there is no need to touch @@ -89,33 +87,14 @@ where } } - let mut hasher = Sha256::new(); - - let mut bytes_read = 0usize; - let bytes_to_read = res.content_length().unwrap_or(u64::MAX) as usize; - - while let Some(chunk) = res.chunk().await.context("failed to get response chunk")? { - bytes_read += chunk.len(); - - hasher.update(&chunk); - data.write_all(&chunk).context("failed to write_all chunk")?; - - if print_progress { - print!( - "\rread {}/{} ({:3}%)", - bytes_read, - bytes_to_read, - ((bytes_read as f32 / bytes_to_read as f32) * 100.0f32).floor() - ); - io::stdout().flush().context("failed to flush stdout")?; - } + if print_progress { + println!("writing to {}", path.display()); } - - data.flush().context("failed to flush data")?; - println!(); + let mut file = File::create(path).context(format!("failed to create path ({:?})", path.display()))?; + res.copy_to(&mut file)?; Ok(DownloadResult { - hash: omaha::Hash::from_bytes(hasher.finalize().into()), - data, + hash: hash_on_disk_sha256(path, None)?, + data: file, }) } diff --git a/src/request.rs b/src/request.rs index c5dc5ee..4adcf40 100644 --- a/src/request.rs +++ b/src/request.rs @@ -28,7 +28,7 @@ pub struct Parameters<'a> { pub machine_id: Cow<'a, str>, } -pub async fn perform<'a>(client: &reqwest::Client, parameters: Parameters<'a>) -> Result { +pub fn perform<'a>(client: &reqwest::blocking::Client, parameters: Parameters<'a>) -> Result { let req_body = { let r = omaha::Request { protocol_version: Cow::Borrowed(PROTOCOL_VERSION), @@ -78,8 +78,7 @@ pub async fn perform<'a>(client: &reqwest::Client, parameters: Parameters<'a>) - let resp = client.post(UPDATE_URL) .body(req_body) .send() - .await .context("client post send({UPDATE_URL}) failed")?; - resp.text().await.context("failed to get response") + resp.text().context("failed to get response") }