diff --git a/crates/engine/tree/Cargo.toml b/crates/engine/tree/Cargo.toml index 6a6a67a5e36b..b5b8fc743645 100644 --- a/crates/engine/tree/Cargo.toml +++ b/crates/engine/tree/Cargo.toml @@ -32,7 +32,6 @@ reth-prune.workspace = true reth-revm.workspace = true reth-stages-api.workspace = true reth-tasks.workspace = true -reth-trie-db.workspace = true reth-trie-parallel.workspace = true reth-trie-sparse.workspace = true reth-trie.workspace = true @@ -82,6 +81,7 @@ reth-stages = { workspace = true, features = ["test-utils"] } reth-static-file.workspace = true reth-testing-utils.workspace = true reth-tracing.workspace = true +reth-trie-db.workspace = true # alloy alloy-rlp.workspace = true @@ -120,6 +120,6 @@ test-utils = [ "reth-static-file", "reth-tracing", "reth-trie/test-utils", - "reth-prune-types?/test-utils", "reth-trie-db/test-utils", + "reth-prune-types?/test-utils", ] diff --git a/crates/engine/tree/benches/state_root_task.rs b/crates/engine/tree/benches/state_root_task.rs index f6a6a4adce78..93f1ff4ec330 100644 --- a/crates/engine/tree/benches/state_root_task.rs +++ b/crates/engine/tree/benches/state_root_task.rs @@ -145,10 +145,34 @@ fn bench_state_root(c: &mut Criterion) { let nodes_sorted = config.nodes_sorted.clone(); let state_sorted = config.state_sorted.clone(); let prefix_sets = config.prefix_sets.clone(); - - (config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets) + let num_threads = std::thread::available_parallelism() + .map_or(1, |num| (num.get() / 2).max(1)); + + let state_root_task_pool = rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .thread_name(|i| format!("proof-worker-{}", i)) + .build() + .expect("Failed to create proof worker thread pool"); + + ( + config, + state_updates, + provider, + nodes_sorted, + state_sorted, + prefix_sets, + state_root_task_pool, + ) }, - |(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)| { + |( + config, + state_updates, + provider, + nodes_sorted, + state_sorted, + prefix_sets, + state_root_task_pool, + )| { let blinded_provider_factory = ProofBlindedProviderFactory::new( InMemoryTrieCursorFactory::new( DatabaseTrieCursorFactory::new(provider.tx_ref()), @@ -162,7 +186,11 @@ fn bench_state_root(c: &mut Criterion) { ); black_box(std::thread::scope(|scope| { - let task = StateRootTask::new(config, blinded_provider_factory); + let task = StateRootTask::new( + config, + blinded_provider_factory, + &state_root_task_pool, + ); let mut hook = task.state_hook(); let handle = task.spawn(scope); diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index cb64d95d8f92..dd931a8e40b8 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -6,20 +6,15 @@ use rayon::iter::{ParallelBridge, ParallelIterator}; use reth_errors::{ProviderError, ProviderResult}; use reth_evm::system_calls::OnStateHook; use reth_provider::{ - providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, - StateCommitmentProvider, + providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, StateCommitmentProvider, }; use reth_trie::{ - hashed_cursor::HashedPostStateCursorFactory, prefix_set::TriePrefixSetsMut, - proof::Proof, - trie_cursor::InMemoryTrieCursorFactory, updates::{TrieUpdates, TrieUpdatesSorted}, HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles, TrieInput, }; -use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory}; -use reth_trie_parallel::root::ParallelStateRootError; +use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError}; use reth_trie_sparse::{ blinded::{BlindedProvider, BlindedProviderFactory}, errors::{SparseStateTrieError, SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind}, @@ -269,7 +264,7 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState { /// to the tree. /// Then it updates relevant leaves according to the result of the transaction. #[derive(Debug)] -pub struct StateRootTask { +pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> { /// Task configuration. config: StateRootConfig, /// Receiver for state root related messages. @@ -283,10 +278,12 @@ pub struct StateRootTask { /// The sparse trie used for the state root calculation. If [`None`], then update is in /// progress. sparse_trie: Option>>, + /// Reference to the shared thread pool for parallel proof generation + thread_pool: &'env rayon::ThreadPool, } #[allow(dead_code)] -impl<'env, Factory, ABP, SBP, BPF> StateRootTask +impl<'env, Factory, ABP, SBP, BPF> StateRootTask<'env, Factory, BPF> where Factory: DatabaseProviderFactory + StateCommitmentProvider @@ -302,7 +299,11 @@ where + 'env, { /// Creates a new state root task with the unified message channel - pub fn new(config: StateRootConfig, blinded_provider: BPF) -> Self { + pub fn new( + config: StateRootConfig, + blinded_provider: BPF, + thread_pool: &'env rayon::ThreadPool, + ) -> Self { let (tx, rx) = channel(); Self { @@ -312,6 +313,7 @@ where fetched_proof_targets: Default::default(), proof_sequencer: ProofSequencer::new(), sparse_trie: Some(Box::new(SparseStateTrie::new(blinded_provider).with_updates(true))), + thread_pool, } } @@ -350,6 +352,7 @@ where fetched_proof_targets: &mut MultiProofTargets, proof_sequence_number: u64, state_root_message_sender: Sender>, + thread_pool: &'env rayon::ThreadPool, ) { let proof_targets = targets.into_iter().map(|address| (keccak256(address), Default::default())).collect(); @@ -362,6 +365,7 @@ where proof_targets, proof_sequence_number, state_root_message_sender, + thread_pool, ); } @@ -375,6 +379,7 @@ where fetched_proof_targets: &mut MultiProofTargets, proof_sequence_number: u64, state_root_message_sender: Sender>, + thread_pool: &'env rayon::ThreadPool, ) { let hashed_state_update = evm_state_to_hashed_post_state(update); @@ -388,6 +393,7 @@ where proof_targets, proof_sequence_number, state_root_message_sender, + thread_pool, ); } @@ -398,22 +404,27 @@ where proof_targets: MultiProofTargets, proof_sequence_number: u64, state_root_message_sender: Sender>, + thread_pool: &'env rayon::ThreadPool, ) { // Dispatch proof gathering for this state update - scope.spawn(move |_| match calculate_multiproof(config, proof_targets.clone()) { - Ok(proof) => { - let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( - Box::new(ProofCalculated { - state_update: hashed_state_update, - targets: proof_targets, - proof, - sequence_number: proof_sequence_number, - }), - )); - } - Err(error) => { - let _ = - state_root_message_sender.send(StateRootMessage::ProofCalculationError(error)); + scope.spawn(move |_| { + let result = calculate_multiproof(thread_pool, config, proof_targets.clone()); + + match result { + Ok(proof) => { + let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( + Box::new(ProofCalculated { + state_update: hashed_state_update, + targets: proof_targets, + proof, + sequence_number: proof_sequence_number, + }), + )); + } + Err(error) => { + let _ = state_root_message_sender + .send(StateRootMessage::ProofCalculationError(error)); + } } }); } @@ -517,6 +528,7 @@ where &mut self.fetched_proof_targets, self.proof_sequencer.next_sequence(), self.tx.clone(), + self.thread_pool, ); } StateRootMessage::StateUpdate(update) => { @@ -540,6 +552,7 @@ where &mut self.fetched_proof_targets, self.proof_sequencer.next_sequence(), self.tx.clone(), + self.thread_pool, ); } StateRootMessage::FinishedStateUpdates => { @@ -717,26 +730,23 @@ fn get_proof_targets( /// Calculate multiproof for the targets. #[inline] fn calculate_multiproof( + thread_pool: &rayon::ThreadPool, config: StateRootConfig, proof_targets: MultiProofTargets, ) -> ProviderResult where - Factory: DatabaseProviderFactory + StateCommitmentProvider, + Factory: + DatabaseProviderFactory + StateCommitmentProvider + Clone + 'static, { - let provider = config.consistent_view.provider_ro()?; - - Ok(Proof::from_tx(provider.tx_ref()) - .with_trie_cursor_factory(InMemoryTrieCursorFactory::new( - DatabaseTrieCursorFactory::new(provider.tx_ref()), - &config.nodes_sorted, - )) - .with_hashed_cursor_factory(HashedPostStateCursorFactory::new( - DatabaseHashedCursorFactory::new(provider.tx_ref()), - &config.state_sorted, - )) - .with_prefix_sets_mut(config.prefix_sets.as_ref().clone()) - .with_branch_node_hash_masks(true) - .multiproof(proof_targets)?) + Ok(ParallelProof::new( + config.consistent_view, + config.nodes_sorted, + config.state_sorted, + config.prefix_sets, + thread_pool, + ) + .with_branch_node_hash_masks(true) + .multiproof(proof_targets)?) } /// Updates the sparse trie with the given proofs and state, and returns the updated trie and the @@ -967,8 +977,17 @@ mod tests { ), config.prefix_sets.clone(), ); + let num_threads = + std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1)); + + let state_root_task_pool = rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .thread_name(|i| format!("proof-worker-{}", i)) + .build() + .expect("Failed to create proof worker thread pool"); + let (root_from_task, _) = std::thread::scope(|std_scope| { - let task = StateRootTask::new(config, blinded_provider_factory); + let task = StateRootTask::new(config, blinded_provider_factory, &state_root_task_pool); let mut state_hook = task.state_hook(); let handle = task.spawn(std_scope); diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 149a53a1e4b2..ef7e34b1970a 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -125,7 +125,7 @@ where ); // Pre-calculate storage roots for accounts which were changed. - tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64); + tracker.set_precomputed_storage_roots(storage_root_targets_len as u64); let mut storage_proofs = B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());