Skip to content

Commit

Permalink
add util for blkidx_to_blkid in functional test
Browse files Browse the repository at this point in the history
  • Loading branch information
MdTeach committed Dec 31, 2024
1 parent 0d77fc2 commit 1a74eaa
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 68 deletions.
4 changes: 4 additions & 0 deletions bin/prover-client/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ pub enum ProvingTaskError {
#[error("Witness not found")]
WitnessNotFound,

/// Occurs when a newly created proving task is expected but none is found.
#[error("No tasks found after creation; at least one was expected")]
NoTasksFound,

/// Occurs when the witness data provided is invalid.
#[error("{0}")]
InvalidWitness(String),
Expand Down
17 changes: 11 additions & 6 deletions bin/prover-client/src/operators/checkpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@ impl CheckpointOperator {
.map_err(|e| ProvingTaskError::RpcError(e.to_string()))?;

let headers = l2_headers.ok_or_else(|| {
error!(%block_num, "No L2 headers found at block height");
ProvingTaskError::WitnessNotFound
error!(%block_num, "Failed to fetch L2 block");
ProvingTaskError::InvalidWitness(format!("Invalid L2 block height {}", block_num))
})?;

let first_header: Buf32 = headers
.first()
.ok_or_else(|| {
error!(%block_num, "Empty L2 headers response");
ProvingTaskError::InvalidWitness("Invalid block height".to_string())
ProvingTaskError::InvalidWitness(format!("Invalid L2 block height {}", block_num))
})?
.block_id
.into();
Expand Down Expand Up @@ -116,7 +115,10 @@ impl ProvingOp for CheckpointOperator {
.l1_batch_operator
.create_task(checkpoint_info.l1_range, task_tracker.clone(), db)
.await?;
let l1_batch_id = l1_batch_keys.first().expect("at least one").context();
let l1_batch_id = l1_batch_keys
.first()
.ok_or_else(|| ProvingTaskError::NoTasksFound)?
.context();

// Doing the manual block idx to id transformation. Will be removed once checkpoint_info
// include the range interms of block_id.
Expand All @@ -130,7 +132,10 @@ impl ProvingOp for CheckpointOperator {
.create_task(l2_range, task_tracker.clone(), db)
.await?;

let l2_batch_id = l2_batch_keys.first().expect("at least one").context();
let l2_batch_id = l2_batch_keys
.first()
.ok_or_else(|| ProvingTaskError::NoTasksFound)?
.context();

let deps = vec![*l1_batch_id, *l2_batch_id];

Expand Down
13 changes: 11 additions & 2 deletions bin/prover-client/src/operators/cl_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use strata_proofimpl_cl_agg::{ClAggInput, ClAggProver};
use strata_rocksdb::prover::db::ProofDb;
use strata_state::id::L2BlockId;
use tokio::sync::Mutex;
use tracing::error;

use super::{cl_stf::ClStfOperator, ProvingOp};
use crate::{errors::ProvingTaskError, hosts, task_tracker::TaskTracker};
Expand Down Expand Up @@ -40,8 +41,16 @@ impl ProvingOp for ClAggOperator {
) -> Result<Vec<ProofKey>, ProvingTaskError> {
let mut cl_stf_deps = Vec::with_capacity(batches.len());

let start_blkid = batches.first().expect("Proof request with empty batch").0;
let end_blkid = batches.last().expect("Proof request with empty batch").1;
// Extract first and last block IDs from batches, error if empty
let (start_blkid, end_blkid) = match (batches.first(), batches.last()) {
(Some(first), Some(last)) => (first.0, last.1),
_ => {
error!("Aggregation task with empty batch");
return Err(ProvingTaskError::InvalidInput(
"Aggregation task with empty batch".into(),
));
}
};

let cl_agg_proof_id = ProofContext::ClAgg(start_blkid, end_blkid);

Expand Down
87 changes: 61 additions & 26 deletions bin/prover-client/src/operators/cl_stf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ use strata_primitives::{
use strata_proofimpl_cl_stf::prover::{ClStfInput, ClStfProver};
use strata_rocksdb::prover::db::ProofDb;
use strata_rpc_api::StrataApiClient;
use strata_rpc_types::RpcBlockHeader;
use strata_state::id::L2BlockId;
use tokio::sync::Mutex;
use tracing::error;

use super::{evm_ee::EvmEeOperator, ProvingOp};
use super::{constants::MAX_PROVING_BLOCK_RANGE, evm_ee::EvmEeOperator, ProvingOp};
use crate::{errors::ProvingTaskError, hosts, task_tracker::TaskTracker};

/// A struct that implements the [`ProvingOp`] trait for Consensus Layer (CL) State Transition
Expand Down Expand Up @@ -49,32 +50,62 @@ impl ClStfOperator {
}
}

pub async fn get_exec_id(&self, cl_block_id: L2BlockId) -> Result<Buf32, ProvingTaskError> {
async fn get_l2_block_header(
&self,
blkid: L2BlockId,
) -> Result<RpcBlockHeader, ProvingTaskError> {
let header = self
.cl_client
.get_header_by_id(cl_block_id)
.get_header_by_id(blkid)
.await
.inspect_err(|_| error!(%cl_block_id, "Failed to fetch corresponding ee data"))
.inspect_err(|_| error!(%blkid, "Failed to fetch corresponding ee data"))
.map_err(|e| ProvingTaskError::RpcError(e.to_string()))?
.expect("invalid height");
.ok_or_else(|| {
error!(%blkid, "L2 Block not found");
ProvingTaskError::InvalidWitness(format!("L2 Block {} not found", blkid))
})?;

Ok(header)
}

/// Retrieves the evm_ee block hash corresponding to the given L2 block ID
pub async fn get_exec_id(&self, cl_block_id: L2BlockId) -> Result<Buf32, ProvingTaskError> {
let header = self.get_l2_block_header(cl_block_id).await?;
let block = self.evm_ee_operator.get_block(header.block_idx).await?;
Ok(block.header.hash.into())
}

/// Retrieves the previous [`L2BlockId`] for the given `L2BlockId`
pub async fn get_prev_block_id(
/// Retrieves the specified number of ancestor block IDs for the given block ID.
pub async fn get_block_ancestors(
&self,
block_id: L2BlockId,
) -> Result<L2BlockId, ProvingTaskError> {
blkid: L2BlockId,
n_ancestors: u64,
) -> Result<Vec<L2BlockId>, ProvingTaskError> {
let mut ancestors = Vec::with_capacity(n_ancestors as usize);
let mut blkid = blkid;
for _ in 0..=n_ancestors {
blkid = self.get_prev_block_id(blkid).await?;
ancestors.push(blkid);
}
Ok(ancestors)
}

/// Retrieves the previous [`L2BlockId`] for the given `L2BlockId`
pub async fn get_prev_block_id(&self, blkid: L2BlockId) -> Result<L2BlockId, ProvingTaskError> {
let l2_block = self
.cl_client
.get_header_by_id(block_id)
.get_header_by_id(blkid)
.await
.inspect_err(|_| error!(%block_id, "Failed to fetch l2_header"))
.inspect_err(|_| error!(%blkid, "Failed to fetch l2_header"))
.map_err(|e| ProvingTaskError::RpcError(e.to_string()))?;

let prev_block: Buf32 = l2_block.expect("invalid height").prev_block.into();
let prev_block: Buf32 = l2_block
.ok_or_else(|| {
error!(%blkid, "L2 Block not found");
ProvingTaskError::InvalidWitness(format!("L2 Block {} not found", blkid))
})?
.prev_block
.into();

Ok(prev_block.into())
}
Expand Down Expand Up @@ -106,7 +137,7 @@ impl ProvingOp for ClStfOperator {

let evm_ee_id = evm_ee_tasks
.first()
.expect("creation of task should result on at least one key")
.ok_or_else(|| ProvingTaskError::NoTasksFound)?
.context();

let cl_stf_id = ProofContext::ClStf(start_block_id, end_block_id);
Expand All @@ -128,32 +159,38 @@ impl ProvingOp for ClStfOperator {
_ => return Err(ProvingTaskError::InvalidInput("CL_STF".to_string())),
};

let start_block = self.get_l2_block_header(start_block_hash).await?;
let end_block = self.get_l2_block_header(end_block_hash).await?;
let num_blocks = end_block.block_idx - start_block.block_idx;
if num_blocks > MAX_PROVING_BLOCK_RANGE {
return Err(ProvingTaskError::InvalidInput(format!(
"Block range exceeds maximum limit {:?}",
task_id.context()
)));
}

// Get ancestor blocks and reverse to oldest-first order
let mut l2_block_ids = self.get_block_ancestors(end_block_hash, num_blocks).await?;
l2_block_ids.reverse();

let mut stf_witness_payloads = Vec::new();
let mut blkid = end_block_hash;
loop {
for l2_block_id in l2_block_ids {
let raw_witness: Option<Vec<u8>> = self
.cl_client
.get_cl_block_witness_raw(blkid)
.get_cl_block_witness_raw(l2_block_id)
.await
.map_err(|e| ProvingTaskError::RpcError(e.to_string()))?;
let witness = raw_witness.ok_or(ProvingTaskError::WitnessNotFound)?;
stf_witness_payloads.push(witness);

if blkid == start_block_hash {
break;
} else {
blkid = self.get_prev_block_id(blkid).await?;
}
}
stf_witness_payloads.reverse();

let evm_ee_ids = db
.get_proof_deps(*task_id.context())
.map_err(ProvingTaskError::DatabaseError)?
.ok_or(ProvingTaskError::DependencyNotFound(*task_id))?;
let evm_ee_id = evm_ee_ids
.first()
.expect("should have at least a dependency");
.ok_or_else(|| ProvingTaskError::NoTasksFound)?;
let evm_ee_key = ProofKey::new(*evm_ee_id, *task_id.host());
let evm_ee_proof = db
.get_proof(evm_ee_key)
Expand All @@ -164,8 +201,6 @@ impl ProvingOp for ClStfOperator {
let rollup_params = self.rollup_params.as_ref().clone();
Ok(ClStfInput {
rollup_params,
// pre_state,
// l2_block,
stf_witness_payloads,
evm_ee_proof,
evm_ee_vk,
Expand Down
6 changes: 6 additions & 0 deletions bin/prover-client/src/operators/constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/// Maximum number of blocks allowed in a proving range.
///
/// This constant serves as a safety limit to prevent proving tasks from processing
/// an excessively large number of blocks. If the number of blocks to prove exceeds
/// this limit, the task will be aborted early.
pub const MAX_PROVING_BLOCK_RANGE: u64 = 1024;
1 change: 1 addition & 0 deletions bin/prover-client/src/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod btc;
pub mod checkpoint;
pub mod cl_agg;
pub mod cl_stf;
mod constants;
pub mod evm_ee;
pub mod l1_batch;
pub mod operator;
Expand Down
15 changes: 0 additions & 15 deletions functional-tests/fn_prover_ckp.py

This file was deleted.

10 changes: 3 additions & 7 deletions functional-tests/fn_prover_cl_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import flexitest

import testenv
from utils import wait_for_proof_with_time_out
from utils import cl_slot_to_block_id, wait_for_proof_with_time_out

# Parameters defining the range of Execution Engine (EE) blocks to be proven.
CL_PROVER_PARAMS = {
Expand All @@ -28,8 +28,8 @@ def main(self, ctx: flexitest.RunContext):
time.sleep(5)

# Dispatch the prover task
start_block_id = self.blockidx_2_blockid(seqrpc, CL_PROVER_PARAMS["start_block"])
end_block_id = self.blockidx_2_blockid(seqrpc, CL_PROVER_PARAMS["end_block"])
start_block_id = cl_slot_to_block_id(seqrpc, CL_PROVER_PARAMS["start_block"])
end_block_id = cl_slot_to_block_id(seqrpc, CL_PROVER_PARAMS["end_block"])

task_ids = prover_client_rpc.dev_strata_proveClBlocks((start_block_id, end_block_id))
task_id = task_ids[0]
Expand All @@ -39,7 +39,3 @@ def main(self, ctx: flexitest.RunContext):

time_out = 10 * 60
wait_for_proof_with_time_out(prover_client_rpc, task_id, time_out=time_out)

def blockidx_2_blockid(self, seqrpc, blockidx):
l2_blks = seqrpc.strata_getHeadersAtIdx(blockidx)
return l2_blks[0]["block_id"]
9 changes: 3 additions & 6 deletions functional-tests/fn_prover_el_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import flexitest

import testenv
from utils import wait_for_proof_with_time_out
from utils import el_slot_to_block_id, wait_for_proof_with_time_out

# Parameters defining the range of Execution Engine (EE) blocks to be proven.
EE_PROVER_PARAMS = {
Expand All @@ -27,8 +27,8 @@ def main(self, ctx: flexitest.RunContext):
time.sleep(5)

# Dispatch the prover task
start_block_id = self.blockidx_2_blockid(rethrpc, EE_PROVER_PARAMS["start_block"])
end_block_id = self.blockidx_2_blockid(rethrpc, EE_PROVER_PARAMS["end_block"])
start_block_id = el_slot_to_block_id(rethrpc, EE_PROVER_PARAMS["start_block"])
end_block_id = el_slot_to_block_id(rethrpc, EE_PROVER_PARAMS["end_block"])

task_ids = prover_client_rpc.dev_strata_proveElBlocks((start_block_id, end_block_id))
self.debug(f"got task ids: {task_ids}")
Expand All @@ -38,6 +38,3 @@ def main(self, ctx: flexitest.RunContext):

time_out = 10 * 60
wait_for_proof_with_time_out(prover_client_rpc, task_id, time_out=time_out)

def blockidx_2_blockid(self, rethrpc, blockid):
return rethrpc.eth_getBlockByNumber(hex(blockid), False)["hash"]
11 changes: 11 additions & 0 deletions functional-tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,14 @@ def submit_da_blob(btcrpc: BitcoindClient, seqrpc: JsonrpcClient, blobdata: str)
timeout=10,
)
return tx


def cl_slot_to_block_id(seqrpc, slot):
"""Convert L2 slot number to block ID."""
l2_blocks = seqrpc.strata_getHeadersAtIdx(slot)
return l2_blocks[0]["block_id"]


def el_slot_to_block_id(rethrpc, block_num):
"""Get EL block hash from block number using Ethereum RPC."""
return rethrpc.eth_getBlockByNumber(hex(block_num), False)["hash"]
6 changes: 4 additions & 2 deletions provers/tests/src/cl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ impl<H: ZkVmHost> ClProofGenerator<H> {
}
}

impl<H: ZkVmHost> ProofGenerator<ClStfProver> for ClProofGenerator<H> {
impl<H: ZkVmHost> ProofGenerator for ClProofGenerator<H> {
type Input = (u64, u64);
type P = ClStfProver;
type H = H;

fn get_input(&self, block_range: &(u64, u64)) -> ZkVmResult<ClStfInput> {
// Generate EL proof required for aggregation
Expand Down Expand Up @@ -61,7 +63,7 @@ impl<H: ZkVmHost> ProofGenerator<ClStfProver> for ClProofGenerator<H> {
mod tests {
use super::*;

fn test_proof<H: ZkVmHost>(cl_prover: ClProofGenerator<H>) {
fn test_proof<H: ZkVmHost>(cl_prover: &ClProofGenerator<H>) {
let start_height = 1;
let end_height = 3;

Expand Down
7 changes: 5 additions & 2 deletions provers/tests/src/el.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ impl<H: ZkVmHost> ElProofGenerator<H> {
}
}

impl<H: ZkVmHost> ProofGenerator<EvmEeProver> for ElProofGenerator<H> {
impl<H: ZkVmHost> ProofGenerator for ElProofGenerator<H> {
type Input = (u64, u64);
type P = EvmEeProver;
type H = H;

fn get_input(&self, block_range: &(u64, u64)) -> ZkVmResult<EvmEeProofInput> {
let (start_block, end_block) = block_range;
let evm_segment = EvmSegment::initialize_from_saved_ee_data(*start_block, *end_block);
Expand All @@ -38,7 +41,7 @@ mod tests {

use super::*;

fn test_proof<H: ZkVmHost>(el_prover: ElProofGenerator<H>) {
fn test_proof<H: ZkVmHost>(el_prover: &ElProofGenerator<H>) {
let start_height = 1;
let end_height = 2;
let _ = el_prover.get_proof(&(start_height, end_height)).unwrap();
Expand Down
Loading

0 comments on commit 1a74eaa

Please sign in to comment.