diff --git a/bin/prover-client/src/errors.rs b/bin/prover-client/src/errors.rs index 4d361949c..c44414774 100644 --- a/bin/prover-client/src/errors.rs +++ b/bin/prover-client/src/errors.rs @@ -50,6 +50,10 @@ pub enum ProvingTaskError { #[error("Witness not found")] WitnessNotFound, + /// Occurs when a newly created proving task is expected but none is found. + #[error("No tasks found after creation; at least one was expected")] + NoTasksFound, + /// Occurs when the witness data provided is invalid. #[error("{0}")] InvalidWitness(String), diff --git a/bin/prover-client/src/operators/checkpoint.rs b/bin/prover-client/src/operators/checkpoint.rs index 1e2c70845..654e532e9 100644 --- a/bin/prover-client/src/operators/checkpoint.rs +++ b/bin/prover-client/src/operators/checkpoint.rs @@ -71,15 +71,14 @@ impl CheckpointOperator { .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; let headers = l2_headers.ok_or_else(|| { - error!(%block_num, "No L2 headers found at block height"); - ProvingTaskError::WitnessNotFound + error!(%block_num, "Failed to fetch L2 block"); + ProvingTaskError::InvalidWitness(format!("Invalid L2 block height {}", block_num)) })?; let first_header: Buf32 = headers .first() .ok_or_else(|| { - error!(%block_num, "Empty L2 headers response"); - ProvingTaskError::InvalidWitness("Invalid block height".to_string()) + ProvingTaskError::InvalidWitness(format!("Invalid L2 block height {}", block_num)) })? .block_id .into(); @@ -116,7 +115,10 @@ impl ProvingOp for CheckpointOperator { .l1_batch_operator .create_task(checkpoint_info.l1_range, task_tracker.clone(), db) .await?; - let l1_batch_id = l1_batch_keys.first().expect("at least one").context(); + let l1_batch_id = l1_batch_keys + .first() + .ok_or_else(|| ProvingTaskError::NoTasksFound)? + .context(); // Doing the manual block idx to id transformation. Will be removed once checkpoint_info // include the range interms of block_id. @@ -130,7 +132,10 @@ impl ProvingOp for CheckpointOperator { .create_task(l2_range, task_tracker.clone(), db) .await?; - let l2_batch_id = l2_batch_keys.first().expect("at least one").context(); + let l2_batch_id = l2_batch_keys + .first() + .ok_or_else(|| ProvingTaskError::NoTasksFound)? + .context(); let deps = vec![*l1_batch_id, *l2_batch_id]; diff --git a/bin/prover-client/src/operators/cl_agg.rs b/bin/prover-client/src/operators/cl_agg.rs index 1d4f701ea..4b0078e31 100644 --- a/bin/prover-client/src/operators/cl_agg.rs +++ b/bin/prover-client/src/operators/cl_agg.rs @@ -6,6 +6,7 @@ use strata_proofimpl_cl_agg::{ClAggInput, ClAggProver}; use strata_rocksdb::prover::db::ProofDb; use strata_state::id::L2BlockId; use tokio::sync::Mutex; +use tracing::error; use super::{cl_stf::ClStfOperator, ProvingOp}; use crate::{errors::ProvingTaskError, hosts, task_tracker::TaskTracker}; @@ -40,8 +41,16 @@ impl ProvingOp for ClAggOperator { ) -> Result, ProvingTaskError> { let mut cl_stf_deps = Vec::with_capacity(batches.len()); - let start_blkid = batches.first().expect("Proof request with empty batch").0; - let end_blkid = batches.last().expect("Proof request with empty batch").1; + // Extract first and last block IDs from batches, error if empty + let (start_blkid, end_blkid) = match (batches.first(), batches.last()) { + (Some(first), Some(last)) => (first.0, last.1), + _ => { + error!("Aggregation task with empty batch"); + return Err(ProvingTaskError::InvalidInput( + "Aggregation task with empty batch".into(), + )); + } + }; let cl_agg_proof_id = ProofContext::ClAgg(start_blkid, end_blkid); diff --git a/bin/prover-client/src/operators/cl_stf.rs b/bin/prover-client/src/operators/cl_stf.rs index 27cdcb153..7a42014e4 100644 --- a/bin/prover-client/src/operators/cl_stf.rs +++ b/bin/prover-client/src/operators/cl_stf.rs @@ -10,11 +10,12 @@ use strata_primitives::{ use strata_proofimpl_cl_stf::prover::{ClStfInput, ClStfProver}; use strata_rocksdb::prover::db::ProofDb; use strata_rpc_api::StrataApiClient; +use strata_rpc_types::RpcBlockHeader; use strata_state::id::L2BlockId; use tokio::sync::Mutex; use tracing::error; -use super::{evm_ee::EvmEeOperator, ProvingOp}; +use super::{constants::MAX_PROVING_BLOCK_RANGE, evm_ee::EvmEeOperator, ProvingOp}; use crate::{errors::ProvingTaskError, hosts, task_tracker::TaskTracker}; /// A struct that implements the [`ProvingOp`] trait for Consensus Layer (CL) State Transition @@ -49,32 +50,62 @@ impl ClStfOperator { } } - pub async fn get_exec_id(&self, cl_block_id: L2BlockId) -> Result { + async fn get_l2_block_header( + &self, + blkid: L2BlockId, + ) -> Result { let header = self .cl_client - .get_header_by_id(cl_block_id) + .get_header_by_id(blkid) .await - .inspect_err(|_| error!(%cl_block_id, "Failed to fetch corresponding ee data")) + .inspect_err(|_| error!(%blkid, "Failed to fetch corresponding ee data")) .map_err(|e| ProvingTaskError::RpcError(e.to_string()))? - .expect("invalid height"); + .ok_or_else(|| { + error!(%blkid, "L2 Block not found"); + ProvingTaskError::InvalidWitness(format!("L2 Block {} not found", blkid)) + })?; + + Ok(header) + } + /// Retrieves the evm_ee block hash corresponding to the given L2 block ID + pub async fn get_exec_id(&self, cl_block_id: L2BlockId) -> Result { + let header = self.get_l2_block_header(cl_block_id).await?; let block = self.evm_ee_operator.get_block(header.block_idx).await?; Ok(block.header.hash.into()) } - /// Retrieves the previous [`L2BlockId`] for the given `L2BlockId` - pub async fn get_prev_block_id( + /// Retrieves the specified number of ancestor block IDs for the given block ID. + pub async fn get_block_ancestors( &self, - block_id: L2BlockId, - ) -> Result { + blkid: L2BlockId, + n_ancestors: u64, + ) -> Result, ProvingTaskError> { + let mut ancestors = Vec::with_capacity(n_ancestors as usize); + let mut blkid = blkid; + for _ in 0..=n_ancestors { + blkid = self.get_prev_block_id(blkid).await?; + ancestors.push(blkid); + } + Ok(ancestors) + } + + /// Retrieves the previous [`L2BlockId`] for the given `L2BlockId` + pub async fn get_prev_block_id(&self, blkid: L2BlockId) -> Result { let l2_block = self .cl_client - .get_header_by_id(block_id) + .get_header_by_id(blkid) .await - .inspect_err(|_| error!(%block_id, "Failed to fetch l2_header")) + .inspect_err(|_| error!(%blkid, "Failed to fetch l2_header")) .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; - let prev_block: Buf32 = l2_block.expect("invalid height").prev_block.into(); + let prev_block: Buf32 = l2_block + .ok_or_else(|| { + error!(%blkid, "L2 Block not found"); + ProvingTaskError::InvalidWitness(format!("L2 Block {} not found", blkid)) + })? + .prev_block + .into(); Ok(prev_block.into()) } @@ -106,7 +137,7 @@ impl ProvingOp for ClStfOperator { let evm_ee_id = evm_ee_tasks .first() - .expect("creation of task should result on at least one key") + .ok_or_else(|| ProvingTaskError::NoTasksFound)? .context(); let cl_stf_id = ProofContext::ClStf(start_block_id, end_block_id); @@ -128,24 +159,30 @@ impl ProvingOp for ClStfOperator { _ => return Err(ProvingTaskError::InvalidInput("CL_STF".to_string())), }; + let start_block = self.get_l2_block_header(start_block_hash).await?; + let end_block = self.get_l2_block_header(end_block_hash).await?; + let num_blocks = end_block.block_idx - start_block.block_idx; + if num_blocks > MAX_PROVING_BLOCK_RANGE { + return Err(ProvingTaskError::InvalidInput(format!( + "Block range exceeds maximum limit {:?}", + task_id.context() + ))); + } + + // Get ancestor blocks and reverse to oldest-first order + let mut l2_block_ids = self.get_block_ancestors(end_block_hash, num_blocks).await?; + l2_block_ids.reverse(); + let mut stf_witness_payloads = Vec::new(); - let mut blkid = end_block_hash; - loop { + for l2_block_id in l2_block_ids { let raw_witness: Option> = self .cl_client - .get_cl_block_witness_raw(blkid) + .get_cl_block_witness_raw(l2_block_id) .await .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; let witness = raw_witness.ok_or(ProvingTaskError::WitnessNotFound)?; stf_witness_payloads.push(witness); - - if blkid == start_block_hash { - break; - } else { - blkid = self.get_prev_block_id(blkid).await?; - } } - stf_witness_payloads.reverse(); let evm_ee_ids = db .get_proof_deps(*task_id.context()) @@ -153,7 +190,7 @@ impl ProvingOp for ClStfOperator { .ok_or(ProvingTaskError::DependencyNotFound(*task_id))?; let evm_ee_id = evm_ee_ids .first() - .expect("should have at least a dependency"); + .ok_or_else(|| ProvingTaskError::NoTasksFound)?; let evm_ee_key = ProofKey::new(*evm_ee_id, *task_id.host()); let evm_ee_proof = db .get_proof(evm_ee_key) @@ -164,8 +201,6 @@ impl ProvingOp for ClStfOperator { let rollup_params = self.rollup_params.as_ref().clone(); Ok(ClStfInput { rollup_params, - // pre_state, - // l2_block, stf_witness_payloads, evm_ee_proof, evm_ee_vk, diff --git a/bin/prover-client/src/operators/constants.rs b/bin/prover-client/src/operators/constants.rs new file mode 100644 index 000000000..63b4d5b93 --- /dev/null +++ b/bin/prover-client/src/operators/constants.rs @@ -0,0 +1,6 @@ +/// Maximum number of blocks allowed in a proving range. +/// +/// This constant serves as a safety limit to prevent proving tasks from processing +/// an excessively large number of blocks. If the number of blocks to prove exceeds +/// this limit, the task will be aborted early. +pub const MAX_PROVING_BLOCK_RANGE: u64 = 1024; diff --git a/bin/prover-client/src/operators/mod.rs b/bin/prover-client/src/operators/mod.rs index a5b074781..5f1d0a710 100644 --- a/bin/prover-client/src/operators/mod.rs +++ b/bin/prover-client/src/operators/mod.rs @@ -30,6 +30,7 @@ pub mod btc; pub mod checkpoint; pub mod cl_agg; pub mod cl_stf; +mod constants; pub mod evm_ee; pub mod l1_batch; pub mod operator; diff --git a/functional-tests/fn_prover_ckp.py b/functional-tests/fn_prover_ckp.py deleted file mode 100644 index 52978eacb..000000000 --- a/functional-tests/fn_prover_ckp.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - -import flexitest - -import testenv -from utils import wait_for_proof_with_time_out - - -@flexitest.register -class ProverClientTest(testenv.StrataTester): - def __init__(self, ctx: flexitest.InitContext): - ctx.set_env("prover") - - def main(self, ctx: flexitest.RunContext): - time_out = 10 * 60 diff --git a/functional-tests/fn_prover_cl_dispatch.py b/functional-tests/fn_prover_cl_dispatch.py index cd7b1002e..fb7e73fa6 100644 --- a/functional-tests/fn_prover_cl_dispatch.py +++ b/functional-tests/fn_prover_cl_dispatch.py @@ -3,7 +3,7 @@ import flexitest import testenv -from utils import wait_for_proof_with_time_out +from utils import cl_slot_to_block_id, wait_for_proof_with_time_out # Parameters defining the range of Execution Engine (EE) blocks to be proven. CL_PROVER_PARAMS = { @@ -28,8 +28,8 @@ def main(self, ctx: flexitest.RunContext): time.sleep(5) # Dispatch the prover task - start_block_id = self.blockidx_2_blockid(seqrpc, CL_PROVER_PARAMS["start_block"]) - end_block_id = self.blockidx_2_blockid(seqrpc, CL_PROVER_PARAMS["end_block"]) + start_block_id = cl_slot_to_block_id(seqrpc, CL_PROVER_PARAMS["start_block"]) + end_block_id = cl_slot_to_block_id(seqrpc, CL_PROVER_PARAMS["end_block"]) task_ids = prover_client_rpc.dev_strata_proveClBlocks((start_block_id, end_block_id)) task_id = task_ids[0] @@ -39,7 +39,3 @@ def main(self, ctx: flexitest.RunContext): time_out = 10 * 60 wait_for_proof_with_time_out(prover_client_rpc, task_id, time_out=time_out) - - def blockidx_2_blockid(self, seqrpc, blockidx): - l2_blks = seqrpc.strata_getHeadersAtIdx(blockidx) - return l2_blks[0]["block_id"] diff --git a/functional-tests/fn_prover_el_dispatch.py b/functional-tests/fn_prover_el_dispatch.py index 374b5f91a..4c247dda9 100644 --- a/functional-tests/fn_prover_el_dispatch.py +++ b/functional-tests/fn_prover_el_dispatch.py @@ -3,7 +3,7 @@ import flexitest import testenv -from utils import wait_for_proof_with_time_out +from utils import el_slot_to_block_id, wait_for_proof_with_time_out # Parameters defining the range of Execution Engine (EE) blocks to be proven. EE_PROVER_PARAMS = { @@ -27,8 +27,8 @@ def main(self, ctx: flexitest.RunContext): time.sleep(5) # Dispatch the prover task - start_block_id = self.blockidx_2_blockid(rethrpc, EE_PROVER_PARAMS["start_block"]) - end_block_id = self.blockidx_2_blockid(rethrpc, EE_PROVER_PARAMS["end_block"]) + start_block_id = el_slot_to_block_id(rethrpc, EE_PROVER_PARAMS["start_block"]) + end_block_id = el_slot_to_block_id(rethrpc, EE_PROVER_PARAMS["end_block"]) task_ids = prover_client_rpc.dev_strata_proveElBlocks((start_block_id, end_block_id)) self.debug(f"got task ids: {task_ids}") @@ -38,6 +38,3 @@ def main(self, ctx: flexitest.RunContext): time_out = 10 * 60 wait_for_proof_with_time_out(prover_client_rpc, task_id, time_out=time_out) - - def blockidx_2_blockid(self, rethrpc, blockid): - return rethrpc.eth_getBlockByNumber(hex(blockid), False)["hash"] diff --git a/functional-tests/utils.py b/functional-tests/utils.py index add7f2485..28a8a0838 100644 --- a/functional-tests/utils.py +++ b/functional-tests/utils.py @@ -493,3 +493,14 @@ def submit_da_blob(btcrpc: BitcoindClient, seqrpc: JsonrpcClient, blobdata: str) timeout=10, ) return tx + + +def cl_slot_to_block_id(seqrpc, slot): + """Convert L2 slot number to block ID.""" + l2_blocks = seqrpc.strata_getHeadersAtIdx(slot) + return l2_blocks[0]["block_id"] + + +def el_slot_to_block_id(rethrpc, block_num): + """Get EL block hash from block number using Ethereum RPC.""" + return rethrpc.eth_getBlockByNumber(hex(block_num), False)["hash"] diff --git a/provers/tests/src/cl.rs b/provers/tests/src/cl.rs index 9a6011349..db28c61d0 100644 --- a/provers/tests/src/cl.rs +++ b/provers/tests/src/cl.rs @@ -19,8 +19,10 @@ impl ClProofGenerator { } } -impl ProofGenerator for ClProofGenerator { +impl ProofGenerator for ClProofGenerator { type Input = (u64, u64); + type P = ClStfProver; + type H = H; fn get_input(&self, block_range: &(u64, u64)) -> ZkVmResult { // Generate EL proof required for aggregation @@ -61,7 +63,7 @@ impl ProofGenerator for ClProofGenerator { mod tests { use super::*; - fn test_proof(cl_prover: ClProofGenerator) { + fn test_proof(cl_prover: &ClProofGenerator) { let start_height = 1; let end_height = 3; diff --git a/provers/tests/src/el.rs b/provers/tests/src/el.rs index 1f020ba01..9acaa43a4 100644 --- a/provers/tests/src/el.rs +++ b/provers/tests/src/el.rs @@ -15,8 +15,11 @@ impl ElProofGenerator { } } -impl ProofGenerator for ElProofGenerator { +impl ProofGenerator for ElProofGenerator { type Input = (u64, u64); + type P = EvmEeProver; + type H = H; + fn get_input(&self, block_range: &(u64, u64)) -> ZkVmResult { let (start_block, end_block) = block_range; let evm_segment = EvmSegment::initialize_from_saved_ee_data(*start_block, *end_block); @@ -38,7 +41,7 @@ mod tests { use super::*; - fn test_proof(el_prover: ElProofGenerator) { + fn test_proof(el_prover: &ElProofGenerator) { let start_height = 1; let end_height = 2; let _ = el_prover.get_proof(&(start_height, end_height)).unwrap(); diff --git a/provers/tests/src/l2_batch.rs b/provers/tests/src/l2_batch.rs index 22c5a0911..5cc13c3f0 100644 --- a/provers/tests/src/l2_batch.rs +++ b/provers/tests/src/l2_batch.rs @@ -18,8 +18,10 @@ impl L2BatchProofGenerator { } } -impl ProofGenerator for L2BatchProofGenerator { +impl ProofGenerator for L2BatchProofGenerator { type Input = Vec<(u64, u64)>; + type P = ClAggProver; + type H = H; fn get_input(&self, batches: &Self::Input) -> ZkVmResult { let mut batch = Vec::new(); @@ -53,7 +55,7 @@ impl ProofGenerator for L2BatchProofGenerator { mod tests { use super::*; - fn test_proof(cl_agg_prover: L2BatchProofGenerator) { + fn test_proof(cl_agg_prover: &L2BatchProofGenerator) { let _ = cl_agg_prover.get_proof(&vec![(1, 3)]).unwrap(); }