From 5673a2b7259b1967b3f58f1ea75b77e9712e1b42 Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 5 Dec 2024 10:31:56 +0800 Subject: [PATCH 1/8] draft --- ceno_zkvm/src/scheme.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 73 +++++++++++++++++++------------ ceno_zkvm/src/scheme/verifier.rs | 37 +++++++++------- ceno_zkvm/src/witness.rs | 39 +++++++++++++++++ mpcs/src/lib.rs | 3 ++ multilinear_extensions/src/mle.rs | 2 + 6 files changed, 113 insertions(+), 43 deletions(-) diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 3143f2287..62af3ab96 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -122,7 +122,7 @@ pub struct ZKVMProof> { pub raw_pi: Vec>, // the evaluation of raw_pi. pub pi_evals: Vec, - opcode_proofs: BTreeMap)>, + opcode_proofs: BTreeMap>)>, table_proofs: BTreeMap)>, } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index e5ae01b8a..69e22380f 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use core::assert_eq; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, sync::Arc, @@ -8,7 +9,7 @@ use ff::Field; use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - mle::{IntoMLE, MultilinearExtension}, + mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, @@ -35,7 +36,7 @@ use crate::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, - virtual_polys::VirtualPolynomials, + virtual_polys::VirtualPolynomials, witness::RowMajorMatrix, }; use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; @@ -90,34 +91,44 @@ impl> ZKVMProver { } exit_span!(span); + + let chunk_size = 1048576; + // commit to main traces - let mut commitments = BTreeMap::new(); - let mut wits = BTreeMap::new(); + let mut wits_and_commitments: BTreeMap, Vec>, PCS::CommitmentWithData)>> = BTreeMap::new(); let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, witness) in witnesses.into_iter_sorted() { let num_instances = witness.num_instances(); + tracing::warn!("committing {} witnesses of size {}..", circuit_name, num_instances); + if num_instances == 0 { + wits_and_commitments.insert(circuit_name.clone(), Vec::new()); + continue; + } let span = entered_span!( "commit to iteration", circuit_name = circuit_name, profiling_2 = true ); - let witness = match num_instances { - 0 => vec![], - _ => { - let witness = witness.into_mles(); - commitments.insert( - circuit_name.clone(), - PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) - .map_err(ZKVMError::PCSError)?, - ); - witness - } - }; + + let witness_chunks = witness.chunk_by_num(chunk_size); + if witness_chunks.len() > 1 { + tracing::warn!("split {circuit_name} witness into {} chunks", witness_chunks.len()); + } + let witness_and_commitment: Vec<_> = witness_chunks.into_iter().map(|witness| -> Result<_, ZKVMError> { + // TODO: should we store the mle result? + tracing::debug!("into mle {}", witness.num_instances()); + let witness_mles = witness.clone().into_mles(); + tracing::debug!("batch_commit_and_write"); + let commitment = PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript).map_err(ZKVMError::PCSError)?; + tracing::debug!("done"); + let arc_mles = witness_mles.into_iter().map(|v| v.into()).collect_vec(); + Ok((witness, arc_mles, commitment)) + }).collect::, _>>()?; + wits_and_commitments.insert(circuit_name.clone(), witness_and_commitment); exit_span!(span); - wits.insert(circuit_name, (witness, num_instances)); - } + }; exit_span!(commit_to_traces_span); // squeeze two challenges from transcript @@ -135,13 +146,14 @@ impl> ZKVMProver { .iter() // Sorted by key. .zip_eq(transcripts.iter_mut().enumerate()) { - let (witness, num_instances) = wits + let mut witness_and_wit: Vec<_> = wits_and_commitments .remove(circuit_name) .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; - if witness.is_empty() { + + if witness_and_wit.is_empty() { continue; } - let wits_commit = commitments.remove(circuit_name).unwrap(); + // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); let is_opcode_circuit = cs.lk_table_expressions.is_empty() @@ -160,11 +172,13 @@ impl> ZKVMProver { for lk_s in &cs.lk_expressions_namespace_map { tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s); } - let opcode_proof = self.create_opcode_proof( + let opcode_proof: Vec<_> = witness_and_wit.into_iter().enumerate().map(|(idx, (witness, arc_mles, wits_commit))| -> Result<_, ZKVMError> { + let num_instances = witness.num_instances(); + let proof = self.create_opcode_proof( circuit_name, &self.pk.pp, pk, - witness.into_iter().map(|w| w.into()).collect_vec(), + arc_mles.into_iter().map(|v| v.into()).collect_vec(), wits_commit, &pi, num_instances, @@ -172,19 +186,24 @@ impl> ZKVMProver { &challenges, )?; tracing::info!( - "generated proof for opcode {} with num_instances={}", - circuit_name, - num_instances + "generated proof for opcode {} with num_instances={}, chunk idx {idx}", + circuit_name, num_instances ); + Ok(proof) + }).collect::, _>>()?; + vm_proof .opcode_proofs .insert(circuit_name.clone(), (i, opcode_proof)); } else { + assert_eq!(witness_and_wit.len(), 1); + let (witness, arc_mles, wits_commit) = witness_and_wit.remove(0); + let num_instances = witness.num_instances(); let (table_proof, pi_in_evals) = self.create_table_proof( circuit_name, &self.pk.pp, pk, - witness.into_iter().map(|v| v.into()).collect_vec(), + arc_mles.into_iter().map(|v| v.into()).collect_vec(), wits_commit, &pi, transcript, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 78f70793f..d476a5327 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -61,10 +61,11 @@ impl> ZKVMVerifier does_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. + // seems a bit adhoc here.. let num_instances = vm_proof .opcode_proofs .get(&HaltInstruction::::name()) - .map(|(_, p)| p.num_instances) + .map(|(_, p)| p[0].num_instances) .unwrap_or(0); if num_instances != (does_halt as usize) { return Err(ZKVMError::VerifyError(format!( @@ -119,8 +120,10 @@ impl> ZKVMVerifier for (name, (_, proof)) in vm_proof.opcode_proofs.iter() { tracing::debug!("read {}'s commit", name); - PCS::write_commitment(&proof.wits_commit, &mut transcript) + for p in proof { + PCS::write_commitment(&p.wits_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; + } } for (name, (_, proof)) in vm_proof.table_proofs.iter() { tracing::debug!("read {}'s commit", name); @@ -140,7 +143,7 @@ impl> ZKVMVerifier let point_eval = PointAndEval::default(); let mut transcripts = transcript.fork(self.vk.circuit_vks.len()); - for (name, (i, opcode_proof)) in vm_proof.opcode_proofs { + for (name, (i, opcode_proofs)) in vm_proof.opcode_proofs { let transcript = &mut transcripts[i]; let circuit_vk = self @@ -148,19 +151,22 @@ impl> ZKVMVerifier .circuit_vks .get(&name) .ok_or(ZKVMError::VKNotFound(name.clone()))?; - let _rand_point = self.verify_opcode_proof( - &name, - &self.vk.vp, - circuit_vk, - &opcode_proof, - pi_evals, - transcript, - NUM_FANIN, - &point_eval, - &challenges, - )?; + for opcode_proof in &opcode_proofs { + let _rand_point = self.verify_opcode_proof( + &name, + &self.vk.vp, + circuit_vk, + &opcode_proof, + pi_evals, + transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; + } + tracing::info!("verified proof for opcode {}", name); - + for opcode_proof in &opcode_proofs { // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().lk_expressions.len(); let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; @@ -177,6 +183,7 @@ impl> ZKVMVerifier opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); logup_sum += opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); + } } for (name, (i, table_proof)) in vm_proof.table_proofs { diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 7acc9ad50..2c05533bd 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -1,4 +1,5 @@ use ff::Field; +use core::assert_eq; use std::{ array, cell::RefCell, @@ -58,7 +59,13 @@ impl RowMajorMatrix { } } + + pub fn num_col(&self) -> usize { + self.num_col + } + pub fn num_instances(&self) -> usize { + tracing::info!("num_instances... {} {} {}", self.values.len(), self.num_col, self.num_padding_rows); self.values.len() / self.num_col - self.num_padding_rows } @@ -96,6 +103,7 @@ impl RowMajorMatrix { } pub fn de_interleaving(mut self) -> Vec> { + tracing::debug!("de_interleaving.."); (0..self.num_col) .map(|i| { self.values @@ -107,6 +115,37 @@ impl RowMajorMatrix { }) .collect() } + + // TODO: should we consume or clone `self`? + pub fn chunk_by_num(&self, instance_num_per_chunk: usize) -> Vec { + let chunk_num = (self.num_instances() + instance_num_per_chunk - 1) / instance_num_per_chunk; + let mut result = Vec::new(); + let mut offset = 0; + for i in 0..chunk_num { + let num_rows = if i < chunk_num - 1 { + instance_num_per_chunk + } else { + self.num_instances() - instance_num_per_chunk * i + }; + let mut values: Vec<_> = self.values[offset..offset + num_rows].to_vec(); + offset += num_rows; + let num_total_rows = next_pow2_instance_padding(num_rows); + //unsafe { values.resize(num_total_rows, MaybeUninit::uninit()) }; + unsafe { values.set_len(num_total_rows * self.num_col) }; + let num_padding_rows = num_total_rows - num_rows; + tracing::info!("chunk_by_num {i}th chunk: num_rows {num_rows}, num_total_rows {num_total_rows}, num_padding_rows {num_padding_rows}"); + result.push(Self { + num_col: self.num_col, + num_padding_rows, + values, + }); + } + assert_eq!(self.num_instances(), result.iter().map(|c| { + tracing::info!("num_instances {}", c.num_instances()); + c.num_instances() + }).sum::()); + result + } } impl RowMajorMatrix { diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index b6716c19e..d012850b2 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -156,8 +156,11 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { polys: &[DenseMultilinearExtension], transcript: &mut Transcript, ) -> Result { + println!("000"); let comm = Self::batch_commit(pp, polys)?; + println!("001"); Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; + println!("002"); Ok(comm) } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 8a182e645..8593f33f0 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -105,6 +105,7 @@ pub trait IntoMLE: Sized { impl IntoMLE> for Vec { fn into_mle(mut self) -> DenseMultilinearExtension { let next_pow2 = self.len().next_power_of_two(); + tracing::info!("into_mle with len {next_pow2}"); self.resize(next_pow2, F::ZERO); DenseMultilinearExtension::from_evaluation_vec_smart::(ceil_log2(next_pow2), self) } @@ -118,6 +119,7 @@ impl> IntoMLEs> { fn into_mles(self) -> Vec> { + tracing::info!("vec-vec into_mles"); self.into_iter().map(|v| v.into_mle()).collect() } } From eafc4f878bdd1a56848a87d75832b5e2921cfde7 Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 5 Dec 2024 12:23:59 +0800 Subject: [PATCH 2/8] can run --- ceno_zkvm/src/scheme/prover.rs | 8 ++++-- ceno_zkvm/src/scheme/verifier.rs | 5 ++++ ceno_zkvm/src/witness.rs | 46 +++++++++++++++++------------- multilinear_extensions/src/util.rs | 5 ++-- 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 69e22380f..5bd85f0c8 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -92,7 +92,7 @@ impl> ZKVMProver { exit_span!(span); - let chunk_size = 1048576; + let chunk_size = 1 * 1048576; // commit to main traces let mut wits_and_commitments: BTreeMap, Vec>, PCS::CommitmentWithData)>> = BTreeMap::new(); @@ -118,8 +118,10 @@ impl> ZKVMProver { } let witness_and_commitment: Vec<_> = witness_chunks.into_iter().map(|witness| -> Result<_, ZKVMError> { // TODO: should we store the mle result? - tracing::debug!("into mle {}", witness.num_instances()); - let witness_mles = witness.clone().into_mles(); + tracing::debug!("into mle: {}", witness.num_instances()); + let witness_clone = witness.clone(); + tracing::debug!("cloned"); + let witness_mles = witness_clone.into_mles(); tracing::debug!("batch_commit_and_write"); let commitment = PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript).map_err(ZKVMError::PCSError)?; tracing::debug!("done"); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index d476a5327..c10915ce8 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -478,6 +478,11 @@ impl> ZKVMVerifier } // verify zero expression (degree = 1) statement, thus no sumcheck + for (expr, name) in cs.assert_zero_expressions.iter().zip_eq(cs.assert_zero_expressions_namespace_map.iter()) { + if eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO { + tracing::error!("checking zero expression {name} failed."); + } + } if cs.assert_zero_expressions.iter().any(|expr| { eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO }) { diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 2c05533bd..48784a510 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -41,14 +41,14 @@ macro_rules! set_fixed_val { } #[derive(Clone)] -pub struct RowMajorMatrix { +pub struct RowMajorMatrix { // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance values: Vec>, num_padding_rows: usize, num_col: usize, } -impl RowMajorMatrix { +impl RowMajorMatrix { pub fn new(num_rows: usize, num_col: usize) -> Self { let num_total_rows = next_pow2_instance_padding(num_rows); let num_padding_rows = num_total_rows - num_rows; @@ -65,8 +65,8 @@ impl RowMajorMatrix { } pub fn num_instances(&self) -> usize { - tracing::info!("num_instances... {} {} {}", self.values.len(), self.num_col, self.num_padding_rows); - self.values.len() / self.num_col - self.num_padding_rows + //tracing::info!("num_instances... {} {} {}", self.values.len(), self.num_col, self.num_padding_rows); + (self.values.len() / self.num_col).checked_sub(self.num_padding_rows).expect("overflow") } pub fn num_padding_instances(&self) -> usize { @@ -117,31 +117,36 @@ impl RowMajorMatrix { } // TODO: should we consume or clone `self`? - pub fn chunk_by_num(&self, instance_num_per_chunk: usize) -> Vec { - let chunk_num = (self.num_instances() + instance_num_per_chunk - 1) / instance_num_per_chunk; + pub fn chunk_by_num(&self, chunk_rows: usize) -> Vec { + let padded_row_num = self.values.len() / self.num_col; + if padded_row_num <= chunk_rows { + return vec![self.clone()]; + } + // padded_row_num and instance_num_per_chunk should both be pow of 2. + assert_eq!(padded_row_num % chunk_rows, 0); + let chunk_num = (self.num_instances() + chunk_rows - 1) / chunk_rows; let mut result = Vec::new(); - let mut offset = 0; for i in 0..chunk_num { - let num_rows = if i < chunk_num - 1 { - instance_num_per_chunk - } else { - self.num_instances() - instance_num_per_chunk * i + let mut values: Vec<_> = self.values[(i * chunk_rows * self.num_col)..((i + 1) * chunk_rows*self.num_col)].to_vec(); + let mut num_padding_rows = 0; + + // Only last chunk contains padding rows. + if i == chunk_num - 1 && self.num_instances() % chunk_rows != 0 { + let num_rows = self.num_instances() % chunk_rows; + let num_total_rows = next_pow2_instance_padding(num_rows); + num_padding_rows = num_total_rows - num_rows; + values.truncate(num_total_rows * self.num_col); }; - let mut values: Vec<_> = self.values[offset..offset + num_rows].to_vec(); - offset += num_rows; - let num_total_rows = next_pow2_instance_padding(num_rows); - //unsafe { values.resize(num_total_rows, MaybeUninit::uninit()) }; - unsafe { values.set_len(num_total_rows * self.num_col) }; - let num_padding_rows = num_total_rows - num_rows; - tracing::info!("chunk_by_num {i}th chunk: num_rows {num_rows}, num_total_rows {num_total_rows}, num_padding_rows {num_padding_rows}"); + + tracing::info!("chunk_by_num {i}th chunk: num_rows {chunk_rows}, num_padding_rows {num_padding_rows}"); result.push(Self { num_col: self.num_col, num_padding_rows, values, }); } - assert_eq!(self.num_instances(), result.iter().map(|c| { - tracing::info!("num_instances {}", c.num_instances()); + assert_eq!(self.num_instances(), result.iter().enumerate().map(|(idx, c)| { + tracing::info!("{idx}chunk num_instances: {}", c.num_instances()); c.num_instances() }).sum::()); result @@ -152,6 +157,7 @@ impl RowMajorMatrix { pub fn into_mles>( self, ) -> Vec> { + tracing::info!("before de_interleaving"); self.de_interleaving().into_mles() } } diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index a0a8e56a2..54c68b055 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -20,9 +20,10 @@ pub fn ceil_log2(x: usize) -> usize { usize_bits - (x - 1).leading_zeros() as usize } -pub fn create_uninit_vec(len: usize) -> Vec> { +pub fn create_uninit_vec(len: usize) -> Vec> { let mut vec: Vec> = Vec::with_capacity(len); - unsafe { vec.set_len(len) }; + vec.resize(len, MaybeUninit::uninit()); + //unsafe { vec.set_len(len) }; vec } From f0693f9716eee117815e2630d6d68ac9da7b97ee Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 5 Dec 2024 12:24:50 +0800 Subject: [PATCH 3/8] fmt --- ceno_zkvm/src/scheme/prover.rs | 61 ++++++++++++++++++++---------- ceno_zkvm/src/scheme/verifier.rs | 48 +++++++++++++---------- ceno_zkvm/src/witness.rs | 32 +++++++++++----- multilinear_extensions/src/util.rs | 2 +- 4 files changed, 90 insertions(+), 53 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 5bd85f0c8..c416aadc2 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,5 +1,5 @@ -use ff_ext::ExtensionField; use core::assert_eq; +use ff_ext::ExtensionField; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, sync::Arc, @@ -36,7 +36,8 @@ use crate::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, - virtual_polys::VirtualPolynomials, witness::RowMajorMatrix, + virtual_polys::VirtualPolynomials, + witness::RowMajorMatrix, }; use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; @@ -91,17 +92,27 @@ impl> ZKVMProver { } exit_span!(span); - let chunk_size = 1 * 1048576; // commit to main traces - let mut wits_and_commitments: BTreeMap, Vec>, PCS::CommitmentWithData)>> = BTreeMap::new(); + let mut wits_and_commitments: BTreeMap< + String, + Vec<( + RowMajorMatrix<_>, + Vec>, + PCS::CommitmentWithData, + )>, + > = BTreeMap::new(); let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, witness) in witnesses.into_iter_sorted() { let num_instances = witness.num_instances(); - tracing::warn!("committing {} witnesses of size {}..", circuit_name, num_instances); + tracing::warn!( + "committing {} witnesses of size {}..", + circuit_name, + num_instances + ); if num_instances == 0 { wits_and_commitments.insert(circuit_name.clone(), Vec::new()); continue; @@ -114,23 +125,31 @@ impl> ZKVMProver { let witness_chunks = witness.chunk_by_num(chunk_size); if witness_chunks.len() > 1 { - tracing::warn!("split {circuit_name} witness into {} chunks", witness_chunks.len()); + tracing::warn!( + "split {circuit_name} witness into {} chunks", + witness_chunks.len() + ); } - let witness_and_commitment: Vec<_> = witness_chunks.into_iter().map(|witness| -> Result<_, ZKVMError> { - // TODO: should we store the mle result? - tracing::debug!("into mle: {}", witness.num_instances()); - let witness_clone = witness.clone(); - tracing::debug!("cloned"); - let witness_mles = witness_clone.into_mles(); - tracing::debug!("batch_commit_and_write"); - let commitment = PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript).map_err(ZKVMError::PCSError)?; - tracing::debug!("done"); - let arc_mles = witness_mles.into_iter().map(|v| v.into()).collect_vec(); - Ok((witness, arc_mles, commitment)) - }).collect::, _>>()?; - wits_and_commitments.insert(circuit_name.clone(), witness_and_commitment); + let witness_and_commitment: Vec<_> = witness_chunks + .into_iter() + .map(|witness| -> Result<_, ZKVMError> { + // TODO: should we store the mle result? + tracing::debug!("into mle: {}", witness.num_instances()); + let witness_clone = witness.clone(); + tracing::debug!("cloned"); + let witness_mles = witness_clone.into_mles(); + tracing::debug!("batch_commit_and_write"); + let commitment = + PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript) + .map_err(ZKVMError::PCSError)?; + tracing::debug!("done"); + let arc_mles = witness_mles.into_iter().map(|v| v.into()).collect_vec(); + Ok((witness, arc_mles, commitment)) + }) + .collect::, _>>()?; + wits_and_commitments.insert(circuit_name.clone(), witness_and_commitment); exit_span!(span); - }; + } exit_span!(commit_to_traces_span); // squeeze two challenges from transcript @@ -193,7 +212,7 @@ impl> ZKVMProver { ); Ok(proof) }).collect::, _>>()?; - + vm_proof .opcode_proofs .insert(circuit_name.clone(), (i, opcode_proof)); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index c10915ce8..258ecdff8 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -65,7 +65,7 @@ impl> ZKVMVerifier let num_instances = vm_proof .opcode_proofs .get(&HaltInstruction::::name()) - .map(|(_, p)| p[0].num_instances) + .map(|(_, p)| p[0].num_instances) .unwrap_or(0); if num_instances != (does_halt as usize) { return Err(ZKVMError::VerifyError(format!( @@ -122,7 +122,7 @@ impl> ZKVMVerifier tracing::debug!("read {}'s commit", name); for p in proof { PCS::write_commitment(&p.wits_commit, &mut transcript) - .map_err(ZKVMError::PCSError)?; + .map_err(ZKVMError::PCSError)?; } } for (name, (_, proof)) in vm_proof.table_proofs.iter() { @@ -164,25 +164,25 @@ impl> ZKVMVerifier &challenges, )?; } - + tracing::info!("verified proof for opcode {}", name); for opcode_proof in &opcode_proofs { - // getting the number of dummy padding item that we used in this opcode circuit - let num_lks = circuit_vk.get_cs().lk_expressions.len(); - let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; - let num_padded_instance = - next_pow2_instance_padding(opcode_proof.num_instances) - opcode_proof.num_instances; - dummy_table_item_multiplicity += num_padded_lks_per_instance - * opcode_proof.num_instances - + num_lks.next_power_of_two() * num_padded_instance; - - prod_r *= opcode_proof.record_r_out_evals.iter().product::(); - prod_w *= opcode_proof.record_w_out_evals.iter().product::(); - - logup_sum += - opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); - logup_sum += - opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks = circuit_vk.get_cs().lk_expressions.len(); + let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; + let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances) + - opcode_proof.num_instances; + dummy_table_item_multiplicity += num_padded_lks_per_instance + * opcode_proof.num_instances + + num_lks.next_power_of_two() * num_padded_instance; + + prod_r *= opcode_proof.record_r_out_evals.iter().product::(); + prod_w *= opcode_proof.record_w_out_evals.iter().product::(); + + logup_sum += + opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); + logup_sum += + opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); } } @@ -478,8 +478,14 @@ impl> ZKVMVerifier } // verify zero expression (degree = 1) statement, thus no sumcheck - for (expr, name) in cs.assert_zero_expressions.iter().zip_eq(cs.assert_zero_expressions_namespace_map.iter()) { - if eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO { + for (expr, name) in cs + .assert_zero_expressions + .iter() + .zip_eq(cs.assert_zero_expressions_namespace_map.iter()) + { + if eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) + != E::ZERO + { tracing::error!("checking zero expression {name} failed."); } } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 48784a510..eb91c25e9 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -1,5 +1,5 @@ -use ff::Field; use core::assert_eq; +use ff::Field; use std::{ array, cell::RefCell, @@ -59,14 +59,15 @@ impl RowMajorMatrix { } } - pub fn num_col(&self) -> usize { self.num_col } pub fn num_instances(&self) -> usize { - //tracing::info!("num_instances... {} {} {}", self.values.len(), self.num_col, self.num_padding_rows); - (self.values.len() / self.num_col).checked_sub(self.num_padding_rows).expect("overflow") + // tracing::info!("num_instances... {} {} {}", self.values.len(), self.num_col, self.num_padding_rows); + (self.values.len() / self.num_col) + .checked_sub(self.num_padding_rows) + .expect("overflow") } pub fn num_padding_instances(&self) -> usize { @@ -127,7 +128,9 @@ impl RowMajorMatrix { let chunk_num = (self.num_instances() + chunk_rows - 1) / chunk_rows; let mut result = Vec::new(); for i in 0..chunk_num { - let mut values: Vec<_> = self.values[(i * chunk_rows * self.num_col)..((i + 1) * chunk_rows*self.num_col)].to_vec(); + let mut values: Vec<_> = self.values + [(i * chunk_rows * self.num_col)..((i + 1) * chunk_rows * self.num_col)] + .to_vec(); let mut num_padding_rows = 0; // Only last chunk contains padding rows. @@ -138,17 +141,26 @@ impl RowMajorMatrix { values.truncate(num_total_rows * self.num_col); }; - tracing::info!("chunk_by_num {i}th chunk: num_rows {chunk_rows}, num_padding_rows {num_padding_rows}"); + tracing::info!( + "chunk_by_num {i}th chunk: num_rows {chunk_rows}, num_padding_rows {num_padding_rows}" + ); result.push(Self { num_col: self.num_col, num_padding_rows, values, }); } - assert_eq!(self.num_instances(), result.iter().enumerate().map(|(idx, c)| { - tracing::info!("{idx}chunk num_instances: {}", c.num_instances()); - c.num_instances() - }).sum::()); + assert_eq!( + self.num_instances(), + result + .iter() + .enumerate() + .map(|(idx, c)| { + tracing::info!("{idx}chunk num_instances: {}", c.num_instances()); + c.num_instances() + }) + .sum::() + ); result } } diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index 54c68b055..d870cf9af 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -23,7 +23,7 @@ pub fn ceil_log2(x: usize) -> usize { pub fn create_uninit_vec(len: usize) -> Vec> { let mut vec: Vec> = Vec::with_capacity(len); vec.resize(len, MaybeUninit::uninit()); - //unsafe { vec.set_len(len) }; + // unsafe { vec.set_len(len) }; vec } From a5a9fe4091490563405da1879f11453d4c25569c Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 5 Dec 2024 12:46:50 +0800 Subject: [PATCH 4/8] rename and clean --- ceno_zkvm/src/scheme/prover.rs | 26 ++++++++++------------- ceno_zkvm/src/witness.rs | 38 +++++++++++----------------------- mpcs/src/lib.rs | 3 --- 3 files changed, 23 insertions(+), 44 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 1848cb6f6..09d119293 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -9,7 +9,7 @@ use ff::Field; use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, IntoMLE, MultilinearExtension}, + mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, @@ -92,9 +92,11 @@ impl> ZKVMProver { } exit_span!(span); - let chunk_size = 1 * 1048576; + // TODO: is it better to set different size of different opcode? + let shard_size = 1 * 1048576; // commit to main traces + // TODO: (1) is it ok to store mle? (2) replace tuple with struct? let mut wits_and_commitments: BTreeMap< String, Vec<( @@ -108,7 +110,7 @@ impl> ZKVMProver { // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, witness) in witnesses.into_iter_sorted() { let num_instances = witness.num_instances(); - tracing::warn!( + tracing::debug!( "committing {} witnesses of size {}..", circuit_name, num_instances @@ -123,7 +125,7 @@ impl> ZKVMProver { profiling_2 = true ); - let witness_chunks = witness.chunk_by_num(chunk_size); + let witness_chunks = witness.shard_by_rows(shard_size); if witness_chunks.len() > 1 { tracing::warn!( "split {circuit_name} witness into {} chunks", @@ -133,18 +135,12 @@ impl> ZKVMProver { let witness_and_commitment: Vec<_> = witness_chunks .into_iter() .map(|witness| -> Result<_, ZKVMError> { - // TODO: should we store the mle result? - tracing::debug!("into mle: {}", witness.num_instances()); - let witness_clone = witness.clone(); - tracing::debug!("cloned"); - let witness_mles = witness_clone.into_mles(); - tracing::debug!("batch_commit_and_write"); + let witness_mles = witness.clone().into_mles(); let commitment = PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript) .map_err(ZKVMError::PCSError)?; - tracing::debug!("done"); - let arc_mles = witness_mles.into_iter().map(|v| v.into()).collect_vec(); - Ok((witness, arc_mles, commitment)) + let mles = witness_mles.into_iter().map(|v| v.into()).collect_vec(); + Ok((witness, mles, commitment)) }) .collect::, _>>()?; wits_and_commitments.insert(circuit_name.clone(), witness_and_commitment); @@ -190,13 +186,13 @@ impl> ZKVMProver { cs.w_expressions.len(), cs.lk_expressions.len(), ); - let opcode_proof: Vec<_> = witness_and_wit.into_iter().enumerate().map(|(idx, (witness, arc_mles, wits_commit))| -> Result<_, ZKVMError> { + let opcode_proof: Vec<_> = witness_and_wit.into_iter().enumerate().map(|(idx, (witness, mles, wits_commit))| -> Result<_, ZKVMError> { let num_instances = witness.num_instances(); let proof = self.create_opcode_proof( circuit_name, &self.pk.pp, pk, - arc_mles.into_iter().map(|v| v.into()).collect_vec(), + mles.into_iter().map(|v| v.into()).collect_vec(), wits_commit, &pi, num_instances, diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index eb91c25e9..0e7ea6d33 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -64,7 +64,6 @@ impl RowMajorMatrix { } pub fn num_instances(&self) -> usize { - // tracing::info!("num_instances... {} {} {}", self.values.len(), self.num_col, self.num_padding_rows); (self.values.len() / self.num_col) .checked_sub(self.num_padding_rows) .expect("overflow") @@ -104,7 +103,6 @@ impl RowMajorMatrix { } pub fn de_interleaving(mut self) -> Vec> { - tracing::debug!("de_interleaving.."); (0..self.num_col) .map(|i| { self.values @@ -118,33 +116,29 @@ impl RowMajorMatrix { } // TODO: should we consume or clone `self`? - pub fn chunk_by_num(&self, chunk_rows: usize) -> Vec { + pub fn shard_by_rows(&self, shard_rows: usize) -> Vec { let padded_row_num = self.values.len() / self.num_col; - if padded_row_num <= chunk_rows { + if padded_row_num <= shard_rows { return vec![self.clone()]; } - // padded_row_num and instance_num_per_chunk should both be pow of 2. - assert_eq!(padded_row_num % chunk_rows, 0); - let chunk_num = (self.num_instances() + chunk_rows - 1) / chunk_rows; - let mut result = Vec::new(); + // padded_row_num and chunk_rows should both be pow of 2. + assert_eq!(padded_row_num % shard_rows, 0); + let chunk_num = (self.num_instances() + shard_rows - 1) / shard_rows; + let mut shards = Vec::new(); for i in 0..chunk_num { let mut values: Vec<_> = self.values - [(i * chunk_rows * self.num_col)..((i + 1) * chunk_rows * self.num_col)] + [(i * shard_rows * self.num_col)..((i + 1) * shard_rows * self.num_col)] .to_vec(); let mut num_padding_rows = 0; // Only last chunk contains padding rows. - if i == chunk_num - 1 && self.num_instances() % chunk_rows != 0 { - let num_rows = self.num_instances() % chunk_rows; + if i == chunk_num - 1 && self.num_instances() % shard_rows != 0 { + let num_rows = self.num_instances() % shard_rows; let num_total_rows = next_pow2_instance_padding(num_rows); num_padding_rows = num_total_rows - num_rows; values.truncate(num_total_rows * self.num_col); }; - - tracing::info!( - "chunk_by_num {i}th chunk: num_rows {chunk_rows}, num_padding_rows {num_padding_rows}" - ); - result.push(Self { + shards.push(Self { num_col: self.num_col, num_padding_rows, values, @@ -152,16 +146,9 @@ impl RowMajorMatrix { } assert_eq!( self.num_instances(), - result - .iter() - .enumerate() - .map(|(idx, c)| { - tracing::info!("{idx}chunk num_instances: {}", c.num_instances()); - c.num_instances() - }) - .sum::() + shards.iter().map(|c| { c.num_instances() }).sum::() ); - result + shards } } @@ -169,7 +156,6 @@ impl RowMajorMatrix { pub fn into_mles>( self, ) -> Vec> { - tracing::info!("before de_interleaving"); self.de_interleaving().into_mles() } } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index d012850b2..b6716c19e 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -156,11 +156,8 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { polys: &[DenseMultilinearExtension], transcript: &mut Transcript, ) -> Result { - println!("000"); let comm = Self::batch_commit(pp, polys)?; - println!("001"); Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; - println!("002"); Ok(comm) } From cf83003ec1fa5915b27b1429ea064634f273e1e1 Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 5 Dec 2024 12:58:06 +0800 Subject: [PATCH 5/8] more clean --- ceno_zkvm/src/scheme/prover.rs | 21 ++++++++++----------- ceno_zkvm/src/scheme/verifier.rs | 2 +- ceno_zkvm/src/witness.rs | 14 +++++--------- multilinear_extensions/src/mle.rs | 2 -- multilinear_extensions/src/util.rs | 5 ++--- 5 files changed, 18 insertions(+), 26 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 09d119293..74bd6c72b 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -125,14 +125,14 @@ impl> ZKVMProver { profiling_2 = true ); - let witness_chunks = witness.shard_by_rows(shard_size); - if witness_chunks.len() > 1 { - tracing::warn!( - "split {circuit_name} witness into {} chunks", - witness_chunks.len() + let witness_shards = witness.shard_by_rows(shard_size); + if witness_shards.len() > 1 { + tracing::info!( + "split {circuit_name} witness into {} shards", + witness_shards.len() ); } - let witness_and_commitment: Vec<_> = witness_chunks + let witness_and_commitment: Vec<_> = witness_shards .into_iter() .map(|witness| -> Result<_, ZKVMError> { let witness_mles = witness.clone().into_mles(); @@ -143,7 +143,7 @@ impl> ZKVMProver { Ok((witness, mles, commitment)) }) .collect::, _>>()?; - wits_and_commitments.insert(circuit_name.clone(), witness_and_commitment); + wits_and_commitments.insert(circuit_name, witness_and_commitment); exit_span!(span); } exit_span!(commit_to_traces_span); @@ -200,8 +200,7 @@ impl> ZKVMProver { &challenges, )?; tracing::info!( - "generated proof for opcode {} with num_instances={}, chunk idx {idx}", - circuit_name, num_instances + "generated proof for opcode {circuit_name} with num_instances={num_instances}, shard idx {idx}" ); Ok(proof) }).collect::, _>>()?; @@ -211,13 +210,13 @@ impl> ZKVMProver { .insert(circuit_name.clone(), (i, opcode_proof)); } else { assert_eq!(witness_and_wit.len(), 1); - let (witness, arc_mles, wits_commit) = witness_and_wit.remove(0); + let (witness, mles, wits_commit) = witness_and_wit.remove(0); let num_instances = witness.num_instances(); let (table_proof, pi_in_evals) = self.create_table_proof( circuit_name, &self.pk.pp, pk, - arc_mles.into_iter().map(|v| v.into()).collect_vec(), + mles.into_iter().map(|v| v.into()).collect_vec(), wits_commit, &pi, transcript, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index a586d9a35..4e952f68d 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -61,7 +61,7 @@ impl> ZKVMVerifier does_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. - // seems a bit adhoc here.. + // TODO: make it less adhoc let num_instances = vm_proof .opcode_proofs .get(&HaltInstruction::::name()) diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 0e7ea6d33..68c5a784a 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -41,14 +41,14 @@ macro_rules! set_fixed_val { } #[derive(Clone)] -pub struct RowMajorMatrix { +pub struct RowMajorMatrix { // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance values: Vec>, num_padding_rows: usize, num_col: usize, } -impl RowMajorMatrix { +impl RowMajorMatrix { pub fn new(num_rows: usize, num_col: usize) -> Self { let num_total_rows = next_pow2_instance_padding(num_rows); let num_padding_rows = num_total_rows - num_rows; @@ -59,10 +59,6 @@ impl RowMajorMatrix { } } - pub fn num_col(&self) -> usize { - self.num_col - } - pub fn num_instances(&self) -> usize { (self.values.len() / self.num_col) .checked_sub(self.num_padding_rows) @@ -123,16 +119,16 @@ impl RowMajorMatrix { } // padded_row_num and chunk_rows should both be pow of 2. assert_eq!(padded_row_num % shard_rows, 0); - let chunk_num = (self.num_instances() + shard_rows - 1) / shard_rows; + let shard_num = (self.num_instances() + shard_rows - 1) / shard_rows; let mut shards = Vec::new(); - for i in 0..chunk_num { + for i in 0..shard_num { let mut values: Vec<_> = self.values [(i * shard_rows * self.num_col)..((i + 1) * shard_rows * self.num_col)] .to_vec(); let mut num_padding_rows = 0; // Only last chunk contains padding rows. - if i == chunk_num - 1 && self.num_instances() % shard_rows != 0 { + if i == shard_num - 1 && self.num_instances() % shard_rows != 0 { let num_rows = self.num_instances() % shard_rows; let num_total_rows = next_pow2_instance_padding(num_rows); num_padding_rows = num_total_rows - num_rows; diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 8593f33f0..8a182e645 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -105,7 +105,6 @@ pub trait IntoMLE: Sized { impl IntoMLE> for Vec { fn into_mle(mut self) -> DenseMultilinearExtension { let next_pow2 = self.len().next_power_of_two(); - tracing::info!("into_mle with len {next_pow2}"); self.resize(next_pow2, F::ZERO); DenseMultilinearExtension::from_evaluation_vec_smart::(ceil_log2(next_pow2), self) } @@ -119,7 +118,6 @@ impl> IntoMLEs> { fn into_mles(self) -> Vec> { - tracing::info!("vec-vec into_mles"); self.into_iter().map(|v| v.into_mle()).collect() } } diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index d870cf9af..a0a8e56a2 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -20,10 +20,9 @@ pub fn ceil_log2(x: usize) -> usize { usize_bits - (x - 1).leading_zeros() as usize } -pub fn create_uninit_vec(len: usize) -> Vec> { +pub fn create_uninit_vec(len: usize) -> Vec> { let mut vec: Vec> = Vec::with_capacity(len); - vec.resize(len, MaybeUninit::uninit()); - // unsafe { vec.set_len(len) }; + unsafe { vec.set_len(len) }; vec } From b078060f9ad1bef15725cac98ebb5c0f76c0f52f Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 5 Dec 2024 13:09:59 +0800 Subject: [PATCH 6/8] fmt --- ceno_zkvm/src/scheme/prover.rs | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 74bd6c72b..95c9d4cc4 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -189,22 +189,21 @@ impl> ZKVMProver { let opcode_proof: Vec<_> = witness_and_wit.into_iter().enumerate().map(|(idx, (witness, mles, wits_commit))| -> Result<_, ZKVMError> { let num_instances = witness.num_instances(); let proof = self.create_opcode_proof( - circuit_name, - &self.pk.pp, - pk, - mles.into_iter().map(|v| v.into()).collect_vec(), - wits_commit, - &pi, - num_instances, - transcript, - &challenges, - )?; - tracing::info!( - "generated proof for opcode {circuit_name} with num_instances={num_instances}, shard idx {idx}" - ); - Ok(proof) - }).collect::, _>>()?; - + circuit_name, + &self.pk.pp, + pk, + mles.into_iter().map(|v| v.into()).collect_vec(), + wits_commit, + &pi, + num_instances, + transcript, + &challenges, + )?; + tracing::info!( + "generated proof for opcode {circuit_name} with num_instances={num_instances}, shard idx {idx}" + ); + Ok(proof) + }).collect::, _>>()?; vm_proof .opcode_proofs .insert(circuit_name.clone(), (i, opcode_proof)); From e7922939ada1d4e581e569152ab8b743ed66b7c2 Mon Sep 17 00:00:00 2001 From: Zhang Zhuo Date: Thu, 5 Dec 2024 13:25:20 +0800 Subject: [PATCH 7/8] lint --- ceno_zkvm/src/scheme/prover.rs | 6 +++--- ceno_zkvm/src/scheme/verifier.rs | 2 +- ceno_zkvm/src/witness.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 95c9d4cc4..566dfeeab 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -93,10 +93,11 @@ impl> ZKVMProver { exit_span!(span); // TODO: is it better to set different size of different opcode? - let shard_size = 1 * 1048576; + let shard_size = 1048576; // commit to main traces // TODO: (1) is it ok to store mle? (2) replace tuple with struct? + #[allow(clippy::type_complexity)] let mut wits_and_commitments: BTreeMap< String, Vec<( @@ -139,8 +140,7 @@ impl> ZKVMProver { let commitment = PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript) .map_err(ZKVMError::PCSError)?; - let mles = witness_mles.into_iter().map(|v| v.into()).collect_vec(); - Ok((witness, mles, commitment)) + Ok((witness, witness_mles, commitment)) }) .collect::, _>>()?; wits_and_commitments.insert(circuit_name, witness_and_commitment); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 4e952f68d..e126481c7 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -156,7 +156,7 @@ impl> ZKVMVerifier &name, &self.vk.vp, circuit_vk, - &opcode_proof, + opcode_proof, pi_evals, transcript, NUM_FANIN, diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 68c5a784a..8979ff664 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -119,7 +119,7 @@ impl RowMajorMatrix { } // padded_row_num and chunk_rows should both be pow of 2. assert_eq!(padded_row_num % shard_rows, 0); - let shard_num = (self.num_instances() + shard_rows - 1) / shard_rows; + let shard_num = self.num_instances().div_ceil(shard_rows); let mut shards = Vec::new(); for i in 0..shard_num { let mut values: Vec<_> = self.values From 809aa09b387ec9d705b67d517759cf675bce007d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Wed, 18 Dec 2024 13:50:44 +0100 Subject: [PATCH 8/8] feat/sharding: fix after merge --- ceno_zkvm/src/scheme/prover.rs | 2 +- ceno_zkvm/src/witness.rs | 20 ++++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 20111a1cf..ff6adbc6e 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -103,7 +103,7 @@ impl> ZKVMProver { Vec<( RowMajorMatrix<_>, Vec>, - PCS::CommitmentWithData, + PCS::CommitmentWithWitness, )>, > = BTreeMap::new(); diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 7d72118a5..d05a4629f 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -106,7 +106,7 @@ impl> RowMajorMatrix } pub fn shard_by_rows(&self, shard_rows: usize) -> Vec { - let padded_row_num = self.values.len() / self.num_col; + let padded_row_num = self.num_instances() + self.num_padding_instances(); if padded_row_num <= shard_rows { return vec![self.clone()]; } @@ -115,22 +115,14 @@ impl> RowMajorMatrix let shard_num = self.num_instances().div_ceil(shard_rows); let mut shards = Vec::new(); for i in 0..shard_num { - let mut values: Vec<_> = self.values - [(i * shard_rows * self.num_col)..((i + 1) * shard_rows * self.num_col)] - .to_vec(); - let mut num_padding_rows = 0; - - // Only last chunk contains padding rows. - if i == shard_num - 1 && self.num_instances() % shard_rows != 0 { - let num_rows = self.num_instances() % shard_rows; - let num_total_rows = next_pow2_instance_padding(num_rows); - num_padding_rows = num_total_rows - num_rows; - values.truncate(num_total_rows * self.num_col); - }; + let start = i * shard_rows * self.num_col; + let end = ((i + 1) * shard_rows * self.num_col).min(self.values.len()); + let values: Vec<_> = self.values[start..end].to_vec(); + shards.push(Self { num_col: self.num_col, - num_padding_rows, values, + padding_strategy: self.padding_strategy.clone(), }); } assert_eq!(