Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(engine): use ParallelProof::multiproof in StateRootTask #13260

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/engine/tree/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ reth-prune.workspace = true
reth-revm.workspace = true
reth-stages-api.workspace = true
reth-tasks.workspace = true
reth-trie-db.workspace = true
reth-trie-parallel.workspace = true
reth-trie-sparse.workspace = true
reth-trie.workspace = true
Expand Down Expand Up @@ -82,6 +81,7 @@ reth-stages = { workspace = true, features = ["test-utils"] }
reth-static-file.workspace = true
reth-testing-utils.workspace = true
reth-tracing.workspace = true
reth-trie-db.workspace = true

# alloy
alloy-rlp.workspace = true
Expand Down Expand Up @@ -120,6 +120,6 @@ test-utils = [
"reth-static-file",
"reth-tracing",
"reth-trie/test-utils",
"reth-prune-types?/test-utils",
"reth-trie-db/test-utils",
"reth-prune-types?/test-utils",
]
36 changes: 32 additions & 4 deletions crates/engine/tree/benches/state_root_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,34 @@ fn bench_state_root(c: &mut Criterion) {
let nodes_sorted = config.nodes_sorted.clone();
let state_sorted = config.state_sorted.clone();
let prefix_sets = config.prefix_sets.clone();

(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)
let num_threads = std::thread::available_parallelism()
.map_or(1, |num| (num.get() / 2).max(1));

let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");

(
config,
state_updates,
provider,
nodes_sorted,
state_sorted,
prefix_sets,
state_root_task_pool,
)
},
|(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)| {
|(
config,
state_updates,
provider,
nodes_sorted,
state_sorted,
prefix_sets,
state_root_task_pool,
)| {
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
Expand All @@ -162,7 +186,11 @@ fn bench_state_root(c: &mut Criterion) {
);

black_box(std::thread::scope(|scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let task = StateRootTask::new(
config,
blinded_provider_factory,
&state_root_task_pool,
);
let mut hook = task.state_hook();
let handle = task.spawn(scope);

Expand Down
99 changes: 59 additions & 40 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,15 @@ use rayon::iter::{ParallelBridge, ParallelIterator};
use reth_errors::{ProviderError, ProviderResult};
use reth_evm::system_calls::OnStateHook;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
StateCommitmentProvider,
providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, StateCommitmentProvider,
};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::Proof,
trie_cursor::InMemoryTrieCursorFactory,
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles,
TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::ParallelStateRootError;
use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError};
use reth_trie_sparse::{
blinded::{BlindedProvider, BlindedProviderFactory},
errors::{SparseStateTrieError, SparseStateTrieResult, SparseTrieError, SparseTrieErrorKind},
Expand Down Expand Up @@ -269,7 +264,7 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> {
/// Task configuration.
config: StateRootConfig<Factory>,
/// Receiver for state root related messages.
Expand All @@ -283,10 +278,12 @@ pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
/// The sparse trie used for the state root calculation. If [`None`], then update is in
/// progress.
sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
/// Reference to the shared thread pool for parallel proof generation
thread_pool: &'env rayon::ThreadPool,
}

#[allow(dead_code)]
impl<'env, Factory, ABP, SBP, BPF> StateRootTask<Factory, BPF>
impl<'env, Factory, ABP, SBP, BPF> StateRootTask<'env, Factory, BPF>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
Expand All @@ -302,7 +299,11 @@ where
+ 'env,
{
/// Creates a new state root task with the unified message channel
pub fn new(config: StateRootConfig<Factory>, blinded_provider: BPF) -> Self {
pub fn new(
config: StateRootConfig<Factory>,
blinded_provider: BPF,
thread_pool: &'env rayon::ThreadPool,
) -> Self {
let (tx, rx) = channel();

Self {
Expand All @@ -312,6 +313,7 @@ where
fetched_proof_targets: Default::default(),
proof_sequencer: ProofSequencer::new(),
sparse_trie: Some(Box::new(SparseStateTrie::new(blinded_provider).with_updates(true))),
thread_pool,
}
}

Expand Down Expand Up @@ -350,6 +352,7 @@ where
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
) {
let proof_targets =
targets.into_iter().map(|address| (keccak256(address), Default::default())).collect();
Expand All @@ -362,6 +365,7 @@ where
proof_targets,
proof_sequence_number,
state_root_message_sender,
thread_pool,
);
}

Expand All @@ -375,6 +379,7 @@ where
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
) {
let hashed_state_update = evm_state_to_hashed_post_state(update);

Expand All @@ -388,6 +393,7 @@ where
proof_targets,
proof_sequence_number,
state_root_message_sender,
thread_pool,
);
}

Expand All @@ -398,22 +404,27 @@ where
proof_targets: MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
) {
// Dispatch proof gathering for this state update
scope.spawn(move |_| match calculate_multiproof(config, proof_targets.clone()) {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ =
state_root_message_sender.send(StateRootMessage::ProofCalculationError(error));
scope.spawn(move |_| {
let result = calculate_multiproof(thread_pool, config, proof_targets.clone());

match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ = state_root_message_sender
.send(StateRootMessage::ProofCalculationError(error));
}
}
});
}
Expand Down Expand Up @@ -517,6 +528,7 @@ where
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool,
);
}
StateRootMessage::StateUpdate(update) => {
Expand All @@ -540,6 +552,7 @@ where
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool,
);
}
StateRootMessage::FinishedStateUpdates => {
Expand Down Expand Up @@ -717,26 +730,23 @@ fn get_proof_targets(
/// Calculate multiproof for the targets.
#[inline]
fn calculate_multiproof<Factory>(
fgimenez marked this conversation as resolved.
Show resolved Hide resolved
thread_pool: &rayon::ThreadPool,
config: StateRootConfig<Factory>,
proof_targets: MultiProofTargets,
) -> ProviderResult<MultiProof>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider,
Factory:
DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
{
let provider = config.consistent_view.provider_ro()?;

Ok(Proof::from_tx(provider.tx_ref())
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&config.nodes_sorted,
))
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&config.state_sorted,
))
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
.with_branch_node_hash_masks(true)
.multiproof(proof_targets)?)
Ok(ParallelProof::new(
config.consistent_view,
config.nodes_sorted,
config.state_sorted,
config.prefix_sets,
thread_pool,
)
.with_branch_node_hash_masks(true)
.multiproof(proof_targets)?)
}

/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
Expand Down Expand Up @@ -967,8 +977,17 @@ mod tests {
),
config.prefix_sets.clone(),
);
let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));

let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");

let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let task = StateRootTask::new(config, blinded_provider_factory, &state_root_task_pool);
let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope);

Expand Down
2 changes: 1 addition & 1 deletion crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ where
);

// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

let mut storage_proofs =
B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
Expand Down
Loading