Skip to content

Commit

Permalink
perf(engine): use ParallelProof::multiproof in StateRootTask (#13260)
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez authored Dec 19, 2024
1 parent 790a1e2 commit 0a0a2d4
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 47 deletions.
4 changes: 2 additions & 2 deletions crates/engine/tree/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ reth-prune.workspace = true
reth-revm.workspace = true
reth-stages-api.workspace = true
reth-tasks.workspace = true
reth-trie-db.workspace = true
reth-trie-parallel.workspace = true
reth-trie-sparse.workspace = true
reth-trie.workspace = true
Expand Down Expand Up @@ -82,6 +81,7 @@ reth-stages = { workspace = true, features = ["test-utils"] }
reth-static-file.workspace = true
reth-testing-utils.workspace = true
reth-tracing.workspace = true
reth-trie-db.workspace = true

# alloy
alloy-rlp.workspace = true
Expand Down Expand Up @@ -120,6 +120,6 @@ test-utils = [
"reth-static-file",
"reth-tracing",
"reth-trie/test-utils",
"reth-prune-types?/test-utils",
"reth-trie-db/test-utils",
"reth-prune-types?/test-utils",
]
36 changes: 32 additions & 4 deletions crates/engine/tree/benches/state_root_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,34 @@ fn bench_state_root(c: &mut Criterion) {
let nodes_sorted = config.nodes_sorted.clone();
let state_sorted = config.state_sorted.clone();
let prefix_sets = config.prefix_sets.clone();

(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)
let num_threads = std::thread::available_parallelism()
.map_or(1, |num| (num.get() / 2).max(1));

let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");

(
config,
state_updates,
provider,
nodes_sorted,
state_sorted,
prefix_sets,
state_root_task_pool,
)
},
|(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)| {
|(
config,
state_updates,
provider,
nodes_sorted,
state_sorted,
prefix_sets,
state_root_task_pool,
)| {
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
Expand All @@ -162,7 +186,11 @@ fn bench_state_root(c: &mut Criterion) {
);

black_box(std::thread::scope(|scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let task = StateRootTask::new(
config,
blinded_provider_factory,
&state_root_task_pool,
);
let mut hook = task.state_hook();
let handle = task.spawn(scope);

Expand Down
99 changes: 59 additions & 40 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,15 @@ use rayon::iter::{ParallelBridge, ParallelIterator};
use reth_errors::{ProviderError, ProviderResult};
use reth_evm::system_calls::OnStateHook;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
StateCommitmentProvider,
providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, StateCommitmentProvider,
};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::Proof,
trie_cursor::InMemoryTrieCursorFactory,
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles,
TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::ParallelStateRootError;
use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError};
use reth_trie_sparse::{
blinded::{BlindedProvider, BlindedProviderFactory},
errors::{SparseStateTrieError, SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind},
Expand Down Expand Up @@ -269,7 +264,7 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> {
/// Task configuration.
config: StateRootConfig<Factory>,
/// Receiver for state root related messages.
Expand All @@ -283,10 +278,12 @@ pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
/// The sparse trie used for the state root calculation. If [`None`], then update is in
/// progress.
sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
/// Reference to the shared thread pool for parallel proof generation
thread_pool: &'env rayon::ThreadPool,
}

#[allow(dead_code)]
impl<'env, Factory, ABP, SBP, BPF> StateRootTask<Factory, BPF>
impl<'env, Factory, ABP, SBP, BPF> StateRootTask<'env, Factory, BPF>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
Expand All @@ -302,7 +299,11 @@ where
+ 'env,
{
/// Creates a new state root task with the unified message channel
pub fn new(config: StateRootConfig<Factory>, blinded_provider: BPF) -> Self {
pub fn new(
config: StateRootConfig<Factory>,
blinded_provider: BPF,
thread_pool: &'env rayon::ThreadPool,
) -> Self {
let (tx, rx) = channel();

Self {
Expand All @@ -312,6 +313,7 @@ where
fetched_proof_targets: Default::default(),
proof_sequencer: ProofSequencer::new(),
sparse_trie: Some(Box::new(SparseStateTrie::new(blinded_provider).with_updates(true))),
thread_pool,
}
}

Expand Down Expand Up @@ -350,6 +352,7 @@ where
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
) {
let proof_targets =
targets.into_iter().map(|address| (keccak256(address), Default::default())).collect();
Expand All @@ -362,6 +365,7 @@ where
proof_targets,
proof_sequence_number,
state_root_message_sender,
thread_pool,
);
}

Expand All @@ -375,6 +379,7 @@ where
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
) {
let hashed_state_update = evm_state_to_hashed_post_state(update);

Expand All @@ -388,6 +393,7 @@ where
proof_targets,
proof_sequence_number,
state_root_message_sender,
thread_pool,
);
}

Expand All @@ -398,22 +404,27 @@ where
proof_targets: MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
) {
// Dispatch proof gathering for this state update
scope.spawn(move |_| match calculate_multiproof(config, proof_targets.clone()) {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ =
state_root_message_sender.send(StateRootMessage::ProofCalculationError(error));
scope.spawn(move |_| {
let result = calculate_multiproof(thread_pool, config, proof_targets.clone());

match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ = state_root_message_sender
.send(StateRootMessage::ProofCalculationError(error));
}
}
});
}
Expand Down Expand Up @@ -517,6 +528,7 @@ where
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool,
);
}
StateRootMessage::StateUpdate(update) => {
Expand All @@ -540,6 +552,7 @@ where
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool,
);
}
StateRootMessage::FinishedStateUpdates => {
Expand Down Expand Up @@ -717,26 +730,23 @@ fn get_proof_targets(
/// Calculate multiproof for the targets.
#[inline]
fn calculate_multiproof<Factory>(
thread_pool: &rayon::ThreadPool,
config: StateRootConfig<Factory>,
proof_targets: MultiProofTargets,
) -> ProviderResult<MultiProof>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider,
Factory:
DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
{
let provider = config.consistent_view.provider_ro()?;

Ok(Proof::from_tx(provider.tx_ref())
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&config.nodes_sorted,
))
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&config.state_sorted,
))
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
.with_branch_node_hash_masks(true)
.multiproof(proof_targets)?)
Ok(ParallelProof::new(
config.consistent_view,
config.nodes_sorted,
config.state_sorted,
config.prefix_sets,
thread_pool,
)
.with_branch_node_hash_masks(true)
.multiproof(proof_targets)?)
}

/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
Expand Down Expand Up @@ -967,8 +977,17 @@ mod tests {
),
config.prefix_sets.clone(),
);
let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));

let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");

let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let task = StateRootTask::new(config, blinded_provider_factory, &state_root_task_pool);
let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope);

Expand Down
2 changes: 1 addition & 1 deletion crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ where
);

// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

let mut storage_proofs =
B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
Expand Down

0 comments on commit 0a0a2d4

Please sign in to comment.