-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: sharding of global challenge phase commitment and opcode proving #695
base: master
Are you sure you want to change the base?
Changes from 8 commits
5673a2b
eafc4f8
f0693f9
2fde280
a5a9fe4
cf83003
b078060
e792293
3dda9c0
809aa09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
use core::assert_eq; | ||
use ff_ext::ExtensionField; | ||
use std::{ | ||
collections::{BTreeMap, BTreeSet, HashMap}, | ||
|
@@ -8,7 +9,7 @@ use ff::Field; | |
use itertools::{Itertools, enumerate, izip}; | ||
use mpcs::PolynomialCommitmentScheme; | ||
use multilinear_extensions::{ | ||
mle::{IntoMLE, MultilinearExtension}, | ||
mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}, | ||
util::ceil_log2, | ||
virtual_poly::build_eq_x_r_vec, | ||
virtual_poly_v2::ArcMultilinearExtension, | ||
|
@@ -36,6 +37,7 @@ use crate::{ | |
}, | ||
utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, | ||
virtual_polys::VirtualPolynomials, | ||
witness::RowMajorMatrix, | ||
}; | ||
|
||
use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; | ||
|
@@ -90,33 +92,59 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> { | |
} | ||
exit_span!(span); | ||
|
||
// TODO: is it better to set different size of different opcode? | ||
let shard_size = 1048576; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might be easier to read in hex. (Because it's not just a random number.) |
||
|
||
// commit to main traces | ||
let mut commitments = BTreeMap::new(); | ||
let mut wits = BTreeMap::new(); | ||
// TODO: (1) is it ok to store mle? (2) replace tuple with struct? | ||
#[allow(clippy::type_complexity)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps use some type synonym or so? |
||
let mut wits_and_commitments: BTreeMap< | ||
String, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the meaning of the key here? A type synonym might be useful here? |
||
Vec<( | ||
RowMajorMatrix<_>, | ||
Vec<DenseMultilinearExtension<_>>, | ||
PCS::CommitmentWithData, | ||
)>, | ||
> = BTreeMap::new(); | ||
|
||
let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); | ||
// commit to opcode circuits first and then commit to table circuits, sorted by name | ||
for (circuit_name, witness) in witnesses.into_iter_sorted() { | ||
let num_instances = witness.num_instances(); | ||
tracing::debug!( | ||
"committing {} witnesses of size {}..", | ||
circuit_name, | ||
num_instances | ||
); | ||
if num_instances == 0 { | ||
wits_and_commitments.insert(circuit_name.clone(), Vec::new()); | ||
continue; | ||
} | ||
let span = entered_span!( | ||
"commit to iteration", | ||
circuit_name = circuit_name, | ||
profiling_2 = true | ||
); | ||
let witness = match num_instances { | ||
0 => vec![], | ||
_ => { | ||
let witness = witness.into_mles(); | ||
commitments.insert( | ||
circuit_name.clone(), | ||
PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) | ||
.map_err(ZKVMError::PCSError)?, | ||
); | ||
witness | ||
} | ||
}; | ||
|
||
let witness_shards = witness.shard_by_rows(shard_size); | ||
if witness_shards.len() > 1 { | ||
tracing::info!( | ||
"split {circuit_name} witness into {} shards", | ||
witness_shards.len() | ||
); | ||
} | ||
let witness_and_commitment: Vec<_> = witness_shards | ||
.into_iter() | ||
.map(|witness| -> Result<_, ZKVMError> { | ||
let witness_mles = witness.clone().into_mles(); | ||
let commitment = | ||
PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript) | ||
.map_err(ZKVMError::PCSError)?; | ||
Ok((witness, witness_mles, commitment)) | ||
}) | ||
.collect::<Result<Vec<_>, _>>()?; | ||
wits_and_commitments.insert(circuit_name, witness_and_commitment); | ||
exit_span!(span); | ||
wits.insert(circuit_name, (witness, num_instances)); | ||
} | ||
exit_span!(commit_to_traces_span); | ||
|
||
|
@@ -135,13 +163,14 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> { | |
.iter() // Sorted by key. | ||
.zip_eq(transcripts.iter_mut().enumerate()) | ||
{ | ||
let (witness, num_instances) = wits | ||
let mut witness_and_wit: Vec<_> = wits_and_commitments | ||
.remove(circuit_name) | ||
.ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; | ||
if witness.is_empty() { | ||
|
||
if witness_and_wit.is_empty() { | ||
continue; | ||
} | ||
let wits_commit = commitments.remove(circuit_name).unwrap(); | ||
|
||
// TODO: add an enum for circuit type either in constraint_system or vk | ||
let cs = pk.get_cs(); | ||
let is_opcode_circuit = cs.lk_table_expressions.is_empty() | ||
|
@@ -157,31 +186,36 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> { | |
cs.w_expressions.len(), | ||
cs.lk_expressions.len(), | ||
); | ||
let opcode_proof = self.create_opcode_proof( | ||
circuit_name, | ||
&self.pk.pp, | ||
pk, | ||
witness.into_iter().map(|w| w.into()).collect_vec(), | ||
wits_commit, | ||
&pi, | ||
num_instances, | ||
transcript, | ||
&challenges, | ||
)?; | ||
tracing::info!( | ||
"generated proof for opcode {} with num_instances={}", | ||
circuit_name, | ||
num_instances | ||
); | ||
let opcode_proof: Vec<_> = witness_and_wit.into_iter().enumerate().map(|(idx, (witness, mles, wits_commit))| -> Result<_, ZKVMError> { | ||
let num_instances = witness.num_instances(); | ||
let proof = self.create_opcode_proof( | ||
circuit_name, | ||
&self.pk.pp, | ||
pk, | ||
mles.into_iter().map(|v| v.into()).collect_vec(), | ||
wits_commit, | ||
&pi, | ||
num_instances, | ||
transcript, | ||
&challenges, | ||
)?; | ||
tracing::info!( | ||
"generated proof for opcode {circuit_name} with num_instances={num_instances}, shard idx {idx}" | ||
); | ||
Ok(proof) | ||
}).collect::<Result<Vec<_>, _>>()?; | ||
vm_proof | ||
.opcode_proofs | ||
.insert(circuit_name.clone(), (i, opcode_proof)); | ||
} else { | ||
assert_eq!(witness_and_wit.len(), 1); | ||
let (witness, mles, wits_commit) = witness_and_wit.remove(0); | ||
let num_instances = witness.num_instances(); | ||
let (table_proof, pi_in_evals) = self.create_table_proof( | ||
circuit_name, | ||
&self.pk.pp, | ||
pk, | ||
witness.into_iter().map(|v| v.into()).collect_vec(), | ||
mles.into_iter().map(|v| v.into()).collect_vec(), | ||
wits_commit, | ||
&pi, | ||
transcript, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,10 +61,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS> | |
does_halt: bool, | ||
) -> Result<bool, ZKVMError> { | ||
// require ecall/halt proof to exist, depending whether we expect a halt. | ||
// TODO: make it less adhoc | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Especially once we support more ecalls. Or perhaps we should introduce a specific 'halt-successfully' introduction. |
||
let num_instances = vm_proof | ||
.opcode_proofs | ||
.get(&HaltInstruction::<E>::name()) | ||
.map(|(_, p)| p.num_instances) | ||
.map(|(_, p)| p[0].num_instances) | ||
.unwrap_or(0); | ||
if num_instances != (does_halt as usize) { | ||
return Err(ZKVMError::VerifyError(format!( | ||
|
@@ -119,8 +120,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS> | |
|
||
for (name, (_, proof)) in vm_proof.opcode_proofs.iter() { | ||
tracing::debug!("read {}'s commit", name); | ||
PCS::write_commitment(&proof.wits_commit, &mut transcript) | ||
.map_err(ZKVMError::PCSError)?; | ||
for p in proof { | ||
PCS::write_commitment(&p.wits_commit, &mut transcript) | ||
.map_err(ZKVMError::PCSError)?; | ||
} | ||
} | ||
for (name, (_, proof)) in vm_proof.table_proofs.iter() { | ||
tracing::debug!("read {}'s commit", name); | ||
|
@@ -140,43 +143,47 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS> | |
let point_eval = PointAndEval::default(); | ||
let mut transcripts = transcript.fork(self.vk.circuit_vks.len()); | ||
|
||
for (name, (i, opcode_proof)) in vm_proof.opcode_proofs { | ||
for (name, (i, opcode_proofs)) in vm_proof.opcode_proofs { | ||
let transcript = &mut transcripts[i]; | ||
|
||
let circuit_vk = self | ||
.vk | ||
.circuit_vks | ||
.get(&name) | ||
.ok_or(ZKVMError::VKNotFound(name.clone()))?; | ||
let _rand_point = self.verify_opcode_proof( | ||
&name, | ||
&self.vk.vp, | ||
circuit_vk, | ||
&opcode_proof, | ||
pi_evals, | ||
transcript, | ||
NUM_FANIN, | ||
&point_eval, | ||
&challenges, | ||
)?; | ||
tracing::info!("verified proof for opcode {}", name); | ||
for opcode_proof in &opcode_proofs { | ||
let _rand_point = self.verify_opcode_proof( | ||
&name, | ||
&self.vk.vp, | ||
circuit_vk, | ||
opcode_proof, | ||
pi_evals, | ||
transcript, | ||
NUM_FANIN, | ||
&point_eval, | ||
&challenges, | ||
)?; | ||
} | ||
|
||
// getting the number of dummy padding item that we used in this opcode circuit | ||
let num_lks = circuit_vk.get_cs().lk_expressions.len(); | ||
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; | ||
let num_padded_instance = | ||
next_pow2_instance_padding(opcode_proof.num_instances) - opcode_proof.num_instances; | ||
dummy_table_item_multiplicity += num_padded_lks_per_instance | ||
* opcode_proof.num_instances | ||
+ num_lks.next_power_of_two() * num_padded_instance; | ||
|
||
prod_r *= opcode_proof.record_r_out_evals.iter().product::<E>(); | ||
prod_w *= opcode_proof.record_w_out_evals.iter().product::<E>(); | ||
|
||
logup_sum += | ||
opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); | ||
logup_sum += | ||
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); | ||
tracing::info!("verified proof for opcode {}", name); | ||
for opcode_proof in &opcode_proofs { | ||
// getting the number of dummy padding item that we used in this opcode circuit | ||
let num_lks = circuit_vk.get_cs().lk_expressions.len(); | ||
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; | ||
let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances) | ||
- opcode_proof.num_instances; | ||
dummy_table_item_multiplicity += num_padded_lks_per_instance | ||
* opcode_proof.num_instances | ||
+ num_lks.next_power_of_two() * num_padded_instance; | ||
|
||
prod_r *= opcode_proof.record_r_out_evals.iter().product::<E>(); | ||
prod_w *= opcode_proof.record_w_out_evals.iter().product::<E>(); | ||
|
||
logup_sum += | ||
opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); | ||
logup_sum += | ||
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); | ||
} | ||
} | ||
|
||
for (name, (i, table_proof)) in vm_proof.table_proofs { | ||
|
@@ -471,6 +478,17 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS> | |
} | ||
|
||
// verify zero expression (degree = 1) statement, thus no sumcheck | ||
for (expr, name) in cs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this logic seems some left over. probably we combine with line 492? |
||
.assert_zero_expressions | ||
.iter() | ||
.zip_eq(cs.assert_zero_expressions_namespace_map.iter()) | ||
{ | ||
if eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) | ||
!= E::ZERO | ||
{ | ||
tracing::error!("checking zero expression {name} failed."); | ||
} | ||
} | ||
if cs.assert_zero_expressions.iter().any(|expr| { | ||
eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO | ||
}) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
use core::assert_eq; | ||
use ff::Field; | ||
use std::{ | ||
array, | ||
|
@@ -59,7 +60,9 @@ impl<T: Sized + Sync + Clone + Send + Copy> RowMajorMatrix<T> { | |
} | ||
|
||
pub fn num_instances(&self) -> usize { | ||
self.values.len() / self.num_col - self.num_padding_rows | ||
(self.values.len() / self.num_col) | ||
.checked_sub(self.num_padding_rows) | ||
.expect("overflow") | ||
} | ||
|
||
pub fn num_padding_instances(&self) -> usize { | ||
|
@@ -107,6 +110,42 @@ impl<T: Sized + Sync + Clone + Send + Copy> RowMajorMatrix<T> { | |
}) | ||
.collect() | ||
} | ||
|
||
// TODO: should we consume or clone `self`? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's natural to consume |
||
pub fn shard_by_rows(&self, shard_rows: usize) -> Vec<Self> { | ||
let padded_row_num = self.values.len() / self.num_col; | ||
if padded_row_num <= shard_rows { | ||
return vec![self.clone()]; | ||
} | ||
// padded_row_num and chunk_rows should both be pow of 2. | ||
assert_eq!(padded_row_num % shard_rows, 0); | ||
let shard_num = self.num_instances().div_ceil(shard_rows); | ||
let mut shards = Vec::new(); | ||
for i in 0..shard_num { | ||
let mut values: Vec<_> = self.values | ||
[(i * shard_rows * self.num_col)..((i + 1) * shard_rows * self.num_col)] | ||
.to_vec(); | ||
let mut num_padding_rows = 0; | ||
|
||
// Only last chunk contains padding rows. | ||
if i == shard_num - 1 && self.num_instances() % shard_rows != 0 { | ||
let num_rows = self.num_instances() % shard_rows; | ||
let num_total_rows = next_pow2_instance_padding(num_rows); | ||
num_padding_rows = num_total_rows - num_rows; | ||
values.truncate(num_total_rows * self.num_col); | ||
}; | ||
shards.push(Self { | ||
num_col: self.num_col, | ||
num_padding_rows, | ||
values, | ||
}); | ||
} | ||
assert_eq!( | ||
self.num_instances(), | ||
shards.iter().map(|c| { c.num_instances() }).sum::<usize>() | ||
); | ||
shards | ||
} | ||
} | ||
|
||
impl<F: Field> RowMajorMatrix<F> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think so. Why choosing
1048576
here?