From f1de6091fc9ff036f636d29c29b4acb7d2b40420 Mon Sep 17 00:00:00 2001 From: even <35983442+10to4@users.noreply.github.com> Date: Thu, 26 Dec 2024 13:38:26 +0800 Subject: [PATCH] Feat/structural witin add (#740) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Work for #654 --------- Co-authored-by: sm.wu Co-authored-by: Matthias Görgens Co-authored-by: Ho Co-authored-by: mcalancea Co-authored-by: noelwei --- ceno_zkvm/src/chip_handler/general.rs | 17 +- ceno_zkvm/src/circuit_builder.rs | 51 +++-- ceno_zkvm/src/expression.rs | 215 ++++++++++++++------ ceno_zkvm/src/expression/monomial.rs | 6 +- ceno_zkvm/src/lib.rs | 1 + ceno_zkvm/src/scheme/mock_prover.rs | 102 +++++----- ceno_zkvm/src/scheme/prover.rs | 91 ++++++--- ceno_zkvm/src/scheme/utils.rs | 11 +- ceno_zkvm/src/scheme/verifier.rs | 149 +++++++------- ceno_zkvm/src/structs.rs | 3 +- ceno_zkvm/src/tables/mod.rs | 1 + ceno_zkvm/src/tables/ops/ops_circuit.rs | 3 +- ceno_zkvm/src/tables/ops/ops_impl.rs | 8 +- ceno_zkvm/src/tables/program.rs | 4 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 9 +- ceno_zkvm/src/tables/ram/ram_impl.rs | 87 ++++---- ceno_zkvm/src/tables/range/range_circuit.rs | 3 +- ceno_zkvm/src/tables/range/range_impl.rs | 8 +- ceno_zkvm/src/uint/arithmetic.rs | 8 +- ceno_zkvm/src/virtual_polys.rs | 1 + 20 files changed, 499 insertions(+), 279 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 780914e78..4fe85f3d9 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -3,7 +3,7 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem, SetTableSpec}, error::ZKVMError, - expression::{Expression, Fixed, Instance, ToExpr, WitIn}, + expression::{Expression, Fixed, Instance, StructuralWitIn, ToExpr, WitIn}, instructions::riscv::constants::{ END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, UINT_LIMBS, @@ -28,6 +28,21 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.create_witin(name_fn) } + pub fn create_structural_witin( + &mut self, + name_fn: N, + max_len: usize, + offset: u32, + multi_factor: usize, + ) -> StructuralWitIn + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .create_structural_witin(name_fn, max_len, offset, multi_factor) + } + pub fn create_fixed(&mut self, name_fn: N) -> Result where NR: Into, diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 5b82168b0..2e3b585c3 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -9,7 +9,7 @@ use crate::{ ROMType, chip_handler::utils::rlc_chip_record, error::ZKVMError, - expression::{Expression, Fixed, Instance, WitIn}, + expression::{Expression, Fixed, Instance, StructuralWitIn, WitIn}, structs::{ProgramParams, ProvingKey, RAMType, VerifyingKey, WitnessId}, witness::RowMajorMatrix, }; @@ -59,14 +59,6 @@ pub struct LogupTableExpression { pub table_len: usize, } -// TODO encapsulate few information of table spec to SetTableAddrType value -// once confirm syntax is friendly and parsed by recursive verifier -#[derive(Clone, Debug)] -pub enum SetTableAddrType { - FixedAddr, - DynamicAddr(DynamicAddr), -} - #[derive(Clone, Debug)] pub struct DynamicAddr { pub addr_witin_id: usize, @@ -75,12 +67,13 @@ pub struct DynamicAddr { #[derive(Clone, Debug)] pub struct SetTableSpec { - pub addr_type: SetTableAddrType, - pub len: usize, + pub len: Option, + pub structural_witins: Vec, } #[derive(Clone, Debug)] pub struct SetTableExpression { + /// table expression pub expr: Expression, // TODO make decision to have enum/struct @@ -92,10 +85,12 @@ pub struct SetTableExpression { pub struct ConstraintSystem { pub(crate) ns: NameSpace, - // pub platform: Platform, pub num_witin: WitnessId, pub witin_namespace_map: Vec, + pub num_structural_witin: WitnessId, + pub structural_witin_namespace_map: Vec, + pub num_fixed: usize, pub fixed_namespace_map: Vec, @@ -152,6 +147,8 @@ impl ConstraintSystem { num_witin: 0, // platform, witin_namespace_map: vec![], + num_structural_witin: 0, + structural_witin_namespace_map: vec![], num_fixed: 0, fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), @@ -209,13 +206,8 @@ impl ConstraintSystem { } pub fn create_witin, N: FnOnce() -> NR>(&mut self, n: N) -> WitIn { - let wit_in = WitIn { - id: { - let id = self.num_witin; - self.num_witin = self.num_witin.strict_add(1); - id - }, - }; + let wit_in = WitIn { id: self.num_witin }; + self.num_witin = self.num_witin.strict_add(1); let path = self.ns.compute_path(n().into()); self.witin_namespace_map.push(path); @@ -223,6 +215,27 @@ impl ConstraintSystem { wit_in } + pub fn create_structural_witin, N: FnOnce() -> NR>( + &mut self, + n: N, + max_len: usize, + offset: u32, + multi_factor: usize, + ) -> StructuralWitIn { + let wit_in = StructuralWitIn { + id: self.num_structural_witin, + max_len, + offset, + multi_factor, + }; + self.num_structural_witin = self.num_structural_witin.strict_add(1); + + let path = self.ns.compute_path(n().into()); + self.structural_witin_namespace_map.push(path); + + wit_in + } + pub fn create_fixed, N: FnOnce() -> NR>( &mut self, n: N, diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index c2b523014..82fb9fe97 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -24,6 +24,10 @@ use crate::{ pub enum Expression { /// WitIn(Id) WitIn(WitnessId), + /// StructuralWitIn is similar with WitIn, but it is structured. + /// These witnesses in StructuralWitIn allow succinct verification directly during the verification processing, rather than requiring a commitment. + /// StructuralWitIn(Id, max_len, offset, multi_factor) + StructuralWitIn(WitnessId, usize, u32, usize), /// This multi-linear polynomial is known at the setup/keygen phase. Fixed(Fixed), /// Public Values @@ -56,6 +60,7 @@ impl Expression { match self { Expression::Fixed(_) => 1, Expression::WitIn(_) => 1, + Expression::StructuralWitIn(..) => 1, Expression::Instance(_) => 0, Expression::Constant(_) => 0, Expression::Sum(a_expr, b_expr) => max(a_expr.degree(), b_expr.degree()), @@ -70,6 +75,7 @@ impl Expression { &self, fixed_in: &impl Fn(&Fixed) -> T, wit_in: &impl Fn(WitnessId) -> T, // witin id + structural_wit_in: &impl Fn(WitnessId, usize, u32, usize) -> T, constant: &impl Fn(E::BaseField) -> T, challenge: &impl Fn(ChallengeId, usize, E, E) -> T, sum: &impl Fn(T, T) -> T, @@ -79,6 +85,7 @@ impl Expression { self.evaluate_with_instance( fixed_in, wit_in, + structural_wit_in, &|_| unreachable!(), constant, challenge, @@ -93,6 +100,7 @@ impl Expression { &self, fixed_in: &impl Fn(&Fixed) -> T, wit_in: &impl Fn(WitnessId) -> T, // witin id + structural_wit_in: &impl Fn(WitnessId, usize, u32, usize) -> T, instance: &impl Fn(Instance) -> T, constant: &impl Fn(E::BaseField) -> T, challenge: &impl Fn(ChallengeId, usize, E, E) -> T, @@ -103,35 +111,94 @@ impl Expression { match self { Expression::Fixed(f) => fixed_in(f), Expression::WitIn(witness_id) => wit_in(*witness_id), + Expression::StructuralWitIn(witness_id, max_len, offset, multi_factor) => { + structural_wit_in(*witness_id, *max_len, *offset, *multi_factor) + } Expression::Instance(i) => instance(*i), Expression::Constant(scalar) => constant(*scalar), Expression::Sum(a, b) => { let a = a.evaluate_with_instance( - fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + fixed_in, + wit_in, + structural_wit_in, + instance, + constant, + challenge, + sum, + product, + scaled, ); let b = b.evaluate_with_instance( - fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + fixed_in, + wit_in, + structural_wit_in, + instance, + constant, + challenge, + sum, + product, + scaled, ); sum(a, b) } Expression::Product(a, b) => { let a = a.evaluate_with_instance( - fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + fixed_in, + wit_in, + structural_wit_in, + instance, + constant, + challenge, + sum, + product, + scaled, ); let b = b.evaluate_with_instance( - fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + fixed_in, + wit_in, + structural_wit_in, + instance, + constant, + challenge, + sum, + product, + scaled, ); product(a, b) } Expression::ScaledSum(x, a, b) => { let x = x.evaluate_with_instance( - fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + fixed_in, + wit_in, + structural_wit_in, + instance, + constant, + challenge, + sum, + product, + scaled, ); let a = a.evaluate_with_instance( - fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + fixed_in, + wit_in, + structural_wit_in, + instance, + constant, + challenge, + sum, + product, + scaled, ); let b = b.evaluate_with_instance( - fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, + fixed_in, + wit_in, + structural_wit_in, + instance, + constant, + challenge, + sum, + product, + scaled, ); scaled(x, a, b) } @@ -149,17 +216,11 @@ impl Expression { self.to_monomial_form_inner() } - pub fn unpack_sum(&self) -> Option<(Expression, Expression)> { - match self { - Expression::Sum(a, b) => Some((a.deref().clone(), b.deref().clone())), - _ => None, - } - } - fn is_zero_expr(expr: &Expression) -> bool { match expr { Expression::Fixed(_) => false, Expression::WitIn(_) => false, + Expression::StructuralWitIn(..) => false, Expression::Instance(_) => false, Expression::Constant(c) => *c == E::BaseField::ZERO, Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b), @@ -176,6 +237,7 @@ impl Expression { ( Expression::Fixed(_) | Expression::WitIn(_) + | Expression::StructuralWitIn(..) | Expression::Challenge(..) | Expression::Constant(_) | Expression::Instance(_), @@ -206,25 +268,18 @@ impl Neg for Expression { type Output = Expression; fn neg(self) -> Self::Output { match self { - Expression::Fixed(_) | Expression::WitIn(_) | Expression::Instance(_) => { - Expression::ScaledSum( - Box::new(self), - Box::new(Expression::Constant(E::BaseField::ONE.neg())), - Box::new(Expression::Constant(E::BaseField::ZERO)), - ) - } - Expression::Constant(c1) => Expression::Constant(c1.neg()), - Expression::Sum(a, b) => { - Expression::Sum(Box::new(-a.deref().clone()), Box::new(-b.deref().clone())) - } - Expression::Product(a, b) => { - Expression::Product(Box::new(-a.deref().clone()), Box::new(b.deref().clone())) - } - Expression::ScaledSum(x, a, b) => Expression::ScaledSum( - x, - Box::new(-a.deref().clone()), - Box::new(-b.deref().clone()), + Expression::Fixed(_) + | Expression::WitIn(_) + | Expression::StructuralWitIn(..) + | Expression::Instance(_) => Expression::ScaledSum( + Box::new(self), + Box::new(Expression::Constant(-E::BaseField::ONE)), + Box::new(Expression::Constant(E::BaseField::ZERO)), ), + Expression::Constant(c1) => Expression::Constant(-c1), + Expression::Sum(a, b) => Expression::Sum(-a, -b), + Expression::Product(a, b) => Expression::Product(-a, b.clone()), + Expression::ScaledSum(x, a, b) => Expression::ScaledSum(x, -a, -b), Expression::Challenge(challenge_id, pow, scalar, offset) => { Expression::Challenge(challenge_id, pow, scalar.neg(), offset.neg()) } @@ -232,6 +287,27 @@ impl Neg for Expression { } } +impl Neg for &Expression { + type Output = Expression; + fn neg(self) -> Self::Output { + self.clone().neg() + } +} + +impl Neg for Box> { + type Output = Box>; + fn neg(self) -> Self::Output { + self.deref().clone().neg().into() + } +} + +impl Neg for &Box> { + type Output = Box>; + fn neg(self) -> Self::Output { + self.clone().neg() + } +} + impl Add for Expression { type Output = Expression; fn add(self, rhs: Expression) -> Expression { @@ -303,11 +379,7 @@ impl Add for Expression { // constant + scaled sum (c1 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) | (Expression::ScaledSum(x, a, b), c1 @ Expression::Constant(_)) => { - Expression::ScaledSum( - x.clone(), - a.clone(), - Box::new(b.deref().clone() + c1.clone()), - ) + Expression::ScaledSum(x.clone(), a.clone(), Box::new(b.deref() + c1)) } _ => Expression::Sum(Box::new(self), Box::new(rhs)), @@ -454,38 +526,22 @@ impl Sub for Expression { // constant - scalesum (c1 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) => { - Expression::ScaledSum( - x.clone(), - Box::new(-a.deref().clone()), - Box::new(c1.clone() - b.deref().clone()), - ) + Expression::ScaledSum(x.clone(), -a, Box::new(c1 - b.deref())) } // scalesum - constant (Expression::ScaledSum(x, a, b), c1 @ Expression::Constant(_)) => { - Expression::ScaledSum( - x.clone(), - a.clone(), - Box::new(b.deref().clone() - c1.clone()), - ) + Expression::ScaledSum(x.clone(), a.clone(), Box::new(b.deref() - c1)) } // challenge - scalesum (c1 @ Expression::Challenge(..), Expression::ScaledSum(x, a, b)) => { - Expression::ScaledSum( - x.clone(), - Box::new(-a.deref().clone()), - Box::new(c1.clone() - b.deref().clone()), - ) + Expression::ScaledSum(x.clone(), -a, Box::new(c1 - b.deref())) } // scalesum - challenge (Expression::ScaledSum(x, a, b), c1 @ Expression::Challenge(..)) => { - Expression::ScaledSum( - x.clone(), - a.clone(), - Box::new(b.deref().clone() - c1.clone()), - ) + Expression::ScaledSum(x.clone(), a.clone(), Box::new(b.deref() - c1)) } _ => Expression::Sum(Box::new(self), Box::new(-rhs)), @@ -702,8 +758,8 @@ impl Mul for Expression { | (c2 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) => { Expression::ScaledSum( x.clone(), - Box::new(a.deref().clone() * c2.clone()), - Box::new(b.deref().clone() * c2.clone()), + Box::new(a.deref() * c2), + Box::new(b.deref() * c2), ) } // scaled * challenge => scaled @@ -711,8 +767,8 @@ impl Mul for Expression { | (c2 @ Expression::Challenge(..), Expression::ScaledSum(x, a, b)) => { Expression::ScaledSum( x.clone(), - Box::new(a.deref().clone() * c2.clone()), - Box::new(b.deref().clone() * c2.clone()), + Box::new(a.deref() * c2), + Box::new(b.deref() * c2), ) } _ => Expression::Product(Box::new(self), Box::new(rhs)), @@ -725,6 +781,13 @@ pub struct WitIn { pub id: WitnessId, } +#[derive(Clone, Debug, Copy)] +pub struct StructuralWitIn { + pub id: WitnessId, + pub max_len: usize, + pub offset: u32, + pub multi_factor: usize, +} #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct Fixed(pub usize); @@ -760,6 +823,12 @@ impl WitIn { } } +impl StructuralWitIn { + pub fn assign(&self, instance: &mut [E::BaseField], value: E::BaseField) { + instance[self.id as usize] = value; + } +} + pub trait ToExpr { type Output; fn expr(&self) -> Self::Output; @@ -779,6 +848,20 @@ impl ToExpr for &WitIn { } } +impl ToExpr for StructuralWitIn { + type Output = Expression; + fn expr(&self) -> Expression { + Expression::StructuralWitIn(self.id, self.max_len, self.offset, self.multi_factor) + } +} + +impl ToExpr for &StructuralWitIn { + type Output = Expression; + fn expr(&self) -> Expression { + Expression::StructuralWitIn(self.id, self.max_len, self.offset, self.multi_factor) + } +} + impl ToExpr for Fixed { type Output = Expression; fn expr(&self) -> Expression { @@ -818,8 +901,8 @@ macro_rules! impl_from_via_ToExpr { )* }; } -impl_from_via_ToExpr!(WitIn, Fixed, Instance); -impl_from_via_ToExpr!(&WitIn, &Fixed, &Instance); +impl_from_via_ToExpr!(WitIn, Fixed, StructuralWitIn, Instance); +impl_from_via_ToExpr!(&WitIn, &Fixed, &StructuralWitIn, &Instance); // Implement From trait for unsigned types of at most 64 bits macro_rules! impl_from_unsigned { @@ -873,6 +956,12 @@ pub mod fmt { } format!("WitIn({})", wit_in) } + Expression::StructuralWitIn(wit_in, max_len, offset, multi_factor) => { + format!( + "StructuralWitIn({}, {}, {}, {})", + wit_in, max_len, offset, multi_factor + ) + } Expression::Challenge(id, pow, scaler, offset) => { if *pow == 1 && *scaler == 1.into() && *offset == 0.into() { format!("Challenge({})", id) diff --git a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs index 7bd41db17..1cbb3a952 100644 --- a/ceno_zkvm/src/expression/monomial.rs +++ b/ceno_zkvm/src/expression/monomial.rs @@ -19,7 +19,7 @@ impl Expression { }] } - Fixed(_) | WitIn(_) | Instance(_) | Challenge(..) => { + Fixed(_) | WitIn(_) | StructuralWitIn(..) | Instance(_) | Challenge(..) => { vec![Term { coeff: Expression::ONE, vars: vec![self.clone()], @@ -146,6 +146,8 @@ mod tests { E::random(&mut rng), E::random(&mut rng), ]; - move |expr: &Expression| eval_by_expr_with_fixed(&fixed, &witnesses, &challenges, expr) + move |expr: &Expression| { + eval_by_expr_with_fixed(&fixed, &witnesses, &[], &challenges, expr) + } } } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 945404ff3..6c93eba65 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -3,6 +3,7 @@ #![feature(stmt_expr_attributes)] #![feature(variant_count)] #![feature(strict_overflow_ops)] +#![feature(let_chains)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 037f8a130..054a5b3cf 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -31,7 +31,6 @@ use std::{ hash::Hash, io::{BufReader, ErrorKind}, marker::PhantomData, - ops::Neg, sync::OnceLock, }; use strum::IntoEnumIterator; @@ -310,7 +309,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> for i in RANGE::content() { let rlc_record = cb.rlc_chip_record(vec![(RANGE::ROM_TYPE as usize).into(), (i as usize).into()]); - let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + let rlc_record = eval_by_expr(&[], &[], &challenge, &rlc_record); t_vec.push(rlc_record.to_canonical_u64_vec()); } } @@ -327,7 +326,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> (b as usize).into(), (c as usize).into(), ]); - let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + let rlc_record = eval_by_expr(&[], &[], &challenge, &rlc_record); t_vec.push(rlc_record.to_canonical_u64_vec()); } } @@ -361,7 +360,7 @@ fn load_once_tables( let base64_encoded = STANDARD_NO_PAD.encode(serde_json::to_string(&challenge).unwrap().as_bytes()); let file_path = format!("table_cache_dev_{:?}.json", base64_encoded); - let table = match File::open(file_path.clone()) { + let table = match File::open(&file_path) { Ok(file) => { let reader = BufReader::new(file); serde_json::from_reader(reader).unwrap() @@ -469,14 +468,15 @@ impl<'a, E: ExtensionField + Hash> MockProver { // require_equal does not always have the form of Expr::Sum as // the sum of witness and constant is expressed as scaled sum - if name.contains("require_equal") && expr.unpack_sum().is_some() { - let (left, right) = expr.unpack_sum().unwrap(); - let right = right.neg(); + if let Expression::Sum(left, right) = expr + && name.contains("require_equal") + { + let right = -right.as_ref(); - let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left); + let left_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, left); let left_evaluated = left_evaluated.get_base_field_vec(); - let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right); + let right_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, &right); let right_evaluated = right_evaluated.get_base_field_vec(); // left_evaluated.len() ?= right_evaluated.len() due to padding instance @@ -485,7 +485,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { { if left_element != right_element { errors.push(MockProverError::AssertEqualError { - left_expression: left.clone(), + left_expression: *left.clone(), right_expression: right.clone(), left: *left_element, right: *right_element, @@ -496,7 +496,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } else { // contains require_zero - let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_base_field_vec(); for (inst_id, element) in enumerate(expr_evaluated) { @@ -519,7 +519,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .iter() .zip_eq(cb.cs.lk_expressions_namespace_map.iter()) { - let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); + let expr_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec @@ -550,7 +550,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { .map(|expr| { // TODO generalized to all inst_id let inst_id = 0; - wit_infer_by_expr(&[], wits_in, pi, &challenge, expr) + wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, expr) .get_base_field_vec()[inst_id] .to_canonical_u64() }) @@ -649,7 +649,8 @@ impl<'a, E: ExtensionField + Hash> MockProver { for row in fixed.iter_rows() { // TODO: Find a better way to obtain the row content. let row = row.iter().map(|v| (*v).into()).collect::>(); - let rlc_record = eval_by_expr_with_fixed(&row, &[], &challenge, &table_expr.values); + let rlc_record = + eval_by_expr_with_fixed(&row, &[], &[], &challenge, &table_expr.values); t_vec.push(rlc_record.to_canonical_u64_vec()); } } @@ -809,21 +810,20 @@ Hints: num_rows ); // gather lookup inputs - for ((expr, annotation), (rom_type, values)) in cs - .lk_expressions - .iter() - .zip(cs.lk_expressions_namespace_map.clone().into_iter()) - .zip(cs.lk_expressions_items_map.clone().into_iter()) - { + for (expr, annotation, (rom_type, values)) in izip!( + &cs.lk_expressions, + &cs.lk_expressions_namespace_map, + &cs.lk_expressions_items_map + ) { let lk_input = - (wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, expr) + (wit_infer_by_expr(&fixed, &witness, &[], &pi_mles, &challenges, expr) .get_ext_field_vec())[..num_rows] .to_vec(); - rom_inputs.entry(rom_type).or_default().push(( + rom_inputs.entry(*rom_type).or_default().push(( lk_input, circuit_name.clone(), - annotation, - values, + annotation.clone(), + values.clone(), )); } } else { @@ -833,19 +833,24 @@ Hints: num_rows ); // gather lookup tables - for (expr, (rom_type, _)) in cs - .lk_table_expressions - .iter() - .zip(cs.lk_expressions_items_map.clone().into_iter()) + for (expr, (rom_type, _)) in + izip!(&cs.lk_table_expressions, &cs.lk_expressions_items_map) { - let lk_table = - wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, &expr.values) - .get_ext_field_vec() - .to_vec(); + let lk_table = wit_infer_by_expr( + &fixed, + &witness, + &[], + &pi_mles, + &challenges, + &expr.values, + ) + .get_ext_field_vec() + .to_vec(); let multiplicity = wit_infer_by_expr( &fixed, &witness, + &[], &pi_mles, &challenges, &expr.multiplicity, @@ -856,11 +861,8 @@ Hints: assert!( rom_tables .insert( - rom_type, - lk_table - .into_iter() - .zip(multiplicity.into_iter()) - .collect::>(), + *rom_type, + izip!(lk_table, multiplicity).collect::>(), ) .is_none(), "cannot assign to rom table {:?} twice", @@ -910,6 +912,7 @@ Hints: eval_by_expr_with_instance( &[], &witness, + &[], &instance, challenges.as_slice(), expr, @@ -968,10 +971,16 @@ Hints: .zip_eq(cs.w_ram_types.iter()) .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { - let write_rlc_records = - (wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, w_rlc_expr) - .get_ext_field_vec())[..*num_rows] - .to_vec(); + let write_rlc_records = (wit_infer_by_expr( + fixed, + witness, + &[], + &pi_mles, + &challenges, + w_rlc_expr, + ) + .get_ext_field_vec())[..*num_rows] + .to_vec(); if $ram_type == RAMType::GlobalState { // w_exprs = [GlobalState, pc, timestamp] @@ -983,6 +992,7 @@ Hints: let v = wit_infer_by_expr( fixed, witness, + &[], &pi_mles, &challenges, expr, @@ -1031,7 +1041,7 @@ Hints: .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { let read_records = - wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, r_expr) + wit_infer_by_expr(fixed, witness, &[], &pi_mles, &challenges, r_expr) .get_ext_field_vec()[..*num_rows] .to_vec(); let mut records = vec![]; @@ -1140,6 +1150,7 @@ Hints: let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) = derive_ram_rws!(RAMType::GlobalState); gs_rs.insert(eval_by_expr_with_instance( + &[], &[], &[], &instance, @@ -1147,6 +1158,7 @@ Hints: &gs_final, )); gs_ws.insert(eval_by_expr_with_instance( + &[], &[], &[], &instance, @@ -1155,27 +1167,25 @@ Hints: )); // gs stores { (pc, timestamp) } - let gs_clone = gs.clone(); find_rw_mismatch!( gs_rs, rs_grp_by_anno, gs_ws, ws_grp_by_anno, RAMType::GlobalState, - gs_clone + gs ); // part2 registers let (reg_rs, rs_grp_by_anno, reg_ws, ws_grp_by_anno, _) = derive_ram_rws!(RAMType::Register); - let gs_clone = gs.clone(); find_rw_mismatch!( reg_rs, rs_grp_by_anno, reg_ws, ws_grp_by_anno, RAMType::Register, - gs_clone + gs ); // part3 memory diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c8cae8bc..daaa1099d 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -21,7 +21,6 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use crate::{ - circuit_builder::SetTableAddrType, error::ZKVMError, expression::Instance, scheme::{ @@ -93,6 +92,7 @@ impl> ZKVMProver { // commit to main traces let mut commitments = BTreeMap::new(); let mut wits = BTreeMap::new(); + let mut structural_wits = BTreeMap::new(); let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name @@ -103,20 +103,46 @@ impl> ZKVMProver { circuit_name = circuit_name, profiling_2 = true ); - let witness = match num_instances { - 0 => vec![], + let num_witin = self + .pk + .circuit_pks + .get(&circuit_name) + .unwrap() + .get_cs() + .num_witin; + + let (witness, structural_witness) = match num_instances { + 0 => (vec![], vec![]), _ => { - let witness = witness.into_mles(); + let mut witness = witness.into_mles(); + let structural_witness = witness.split_off(num_witin as usize); commitments.insert( circuit_name.clone(), PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) .map_err(ZKVMError::PCSError)?, ); - witness + + (witness, structural_witness) } }; exit_span!(span); - wits.insert(circuit_name, (witness, num_instances)); + wits.insert( + circuit_name.clone(), + ( + witness.into_iter().map(|w| w.into()).collect_vec(), + num_instances, + ), + ); + structural_wits.insert( + circuit_name, + ( + structural_witness + .into_iter() + .map(|v| v.into()) + .collect_vec(), + num_instances, + ), + ); } exit_span!(commit_to_traces_span); @@ -161,7 +187,7 @@ impl> ZKVMProver { circuit_name, &self.pk.pp, pk, - witness.into_iter().map(|w| w.into()).collect_vec(), + witness, wits_commit, &pi, num_instances, @@ -177,20 +203,25 @@ impl> ZKVMProver { .opcode_proofs .insert(circuit_name.clone(), (i, opcode_proof)); } else { + let (structural_witness, structural_num_instances) = structural_wits + .remove(circuit_name) + .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; let (table_proof, pi_in_evals) = self.create_table_proof( circuit_name, &self.pk.pp, pk, - witness.into_iter().map(|v| v.into()).collect_vec(), + witness, wits_commit, + structural_witness, &pi, transcript, &challenges, )?; tracing::info!( - "generated proof for table {} with num_instances={}", + "generated proof for table {} with num_instances={}, structural_num_instances={}", circuit_name, - num_instances + num_instances, + structural_num_instances ); vm_proof .table_proofs @@ -245,7 +276,7 @@ impl> ZKVMProver { .chain(cs.lk_expressions.par_iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr) + wit_infer_by_expr(&[], &witnesses, &[], pi, challenges, expr) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -525,7 +556,7 @@ impl> ZKVMProver { // sanity check in debug build and output != instance index for zero check sumcheck poly if cfg!(debug_assertions) { let expected_zero_poly = - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr); + wit_infer_by_expr(&[], &witnesses, &[], pi, challenges, expr); let top_100_errors = expected_zero_poly .get_base_field_vec() .iter() @@ -662,6 +693,7 @@ impl> ZKVMProver { circuit_pk: &ProvingKey, witnesses: Vec>, wits_commit: PCS::CommitmentWithWitness, + structural_witnesses: Vec>, pi: &[ArcMultilinearExtension<'_, E>], transcript: &mut impl Transcript, challenges: &[E; 2], @@ -679,6 +711,7 @@ impl> ZKVMProver { .unwrap_or_default(); // sanity check assert_eq!(witnesses.len(), cs.num_witin as usize); + assert_eq!(structural_witnesses.len(), cs.num_structural_witin as usize); assert_eq!(fixed.len(), cs.num_fixed); // check all witness size are power of 2 assert!( @@ -686,6 +719,11 @@ impl> ZKVMProver { .iter() .all(|v| { v.evaluations().len().is_power_of_two() }) ); + assert!( + structural_witnesses + .iter() + .all(|v| { v.evaluations().len().is_power_of_two() }) + ); assert!( !cs.r_table_expressions.is_empty() || !cs.w_table_expressions.is_empty() @@ -714,7 +752,14 @@ impl> ZKVMProver { .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr) + wit_infer_by_expr( + &fixed, + &witnesses, + &structural_witnesses, + pi, + challenges, + expr, + ) }) .collect(); let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap(); @@ -885,6 +930,15 @@ impl> ZKVMProver { }) .collect_vec(); + // (non uniform) collect dynamic address hints as witness for verifier + let rw_hints_num_vars = structural_witnesses + .iter() + .map(|mle| mle.num_vars()) + .collect_vec(); + for var in rw_hints_num_vars.iter() { + transcript.append_message(&var.to_le_bytes()); + } + let (rt_tower, tower_proof) = TowerProver::create_proof( // pattern [r1, w1, r2, w2, ...] same pair are chain together r_wit_layers @@ -1035,17 +1089,6 @@ impl> ZKVMProver { } exit_span!(span); - // (non uniform) collect dynamic address hints as witness for verifier - // for fix address, we just fill 0, as verifier will derive it from vk - let rw_hints_num_vars = izip!(&cs.r_table_expressions, r_set_wit.iter()) - .map(|(t, mle)| match t.table_spec.addr_type { - // for fixed address, prover - SetTableAddrType::FixedAddr => 0, - SetTableAddrType::DynamicAddr(_) => mle.num_vars(), - }) - .collect_vec(); - // TODO implement mechanism to skip commitment - let pcs_opening = entered_span!("pcs_opening"); let (fixed_opening_proof, _fixed_commit) = if !fixed.is_empty() { ( diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c8ec6453a..16746c9b1 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -236,6 +236,7 @@ pub(crate) fn infer_tower_product_witness( pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], + structual_witnesses: &[ArcMultilinearExtension<'a, E>], instance: &[ArcMultilinearExtension<'a, E>], challenges: &[E; N], expr: &Expression, @@ -243,6 +244,7 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( expr.evaluate_with_instance::>( &|f| fixed[f.0].clone(), &|witness_id| witnesses[witness_id as usize].clone(), + &|witness_id, _, _, _| structual_witnesses[witness_id as usize].clone(), &|i| instance[i.0].clone(), &|scalar| { let scalar: ArcMultilinearExtension = @@ -349,21 +351,24 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( pub(crate) fn eval_by_expr( witnesses: &[E], + structural_witnesses: &[E], challenges: &[E], expr: &Expression, ) -> E { - eval_by_expr_with_fixed(&[], witnesses, challenges, expr) + eval_by_expr_with_fixed(&[], witnesses, structural_witnesses, challenges, expr) } pub(crate) fn eval_by_expr_with_fixed( fixed: &[E], witnesses: &[E], + structural_witnesses: &[E], challenges: &[E], expr: &Expression, ) -> E { expr.evaluate::( &|f| fixed[f.0], &|witness_id| witnesses[witness_id as usize], + &|witness_id, _, _, _| structural_witnesses[witness_id as usize], &|scalar| scalar.into(), &|challenge_id, pow, scalar, offset| { // TODO cache challenge power to be acquired once for each power @@ -379,6 +384,7 @@ pub(crate) fn eval_by_expr_with_fixed( pub fn eval_by_expr_with_instance( fixed: &[E], witnesses: &[E], + structural_witnesses: &[E], instance: &[E], challenges: &[E], expr: &Expression, @@ -386,6 +392,7 @@ pub fn eval_by_expr_with_instance( expr.evaluate_with_instance::( &|f| fixed[f.0], &|witness_id| witnesses[witness_id as usize], + &|witness_id, _, _, _| structural_witnesses[witness_id as usize], &|i| instance[i.0], &|scalar| scalar.into(), &|challenge_id, pow, scalar, offset| { @@ -681,6 +688,7 @@ mod tests { ], &[], &[], + &[], &expr, ); res.get_base_field_vec(); @@ -710,6 +718,7 @@ mod tests { vec![B::from(3)].into_mle().into(), ], &[], + &[], &[E::ONE], &expr, ); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index e3cb38b2d..9a082c25f 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,10 +1,9 @@ use std::marker::PhantomData; use ark_std::iterable::Iterable; -use ceno_emul::WORD_SIZE; use ff_ext::ExtensionField; -use itertools::{Itertools, interleave, izip}; +use itertools::{Itertools, chain, interleave, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, @@ -15,9 +14,8 @@ use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::{ForkableTranscript, Transcript}; use crate::{ - circuit_builder::SetTableAddrType, error::ZKVMError, - expression::Instance, + expression::{Instance, StructuralWitIn}, instructions::{Instruction, riscv::ecall::HaltInstruction}, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, @@ -223,6 +221,7 @@ impl> ZKVMVerifier } let initial_global_state = eval_by_expr_with_instance( + &[], &[], &[], pi_evals, @@ -231,6 +230,7 @@ impl> ZKVMVerifier ); prod_w *= initial_global_state; let finalize_global_state = eval_by_expr_with_instance( + &[], &[], &[], pi_evals, @@ -433,6 +433,7 @@ impl> ZKVMVerifier * eval_by_expr_with_instance( &[], &proof.wits_in_evals, + &[], pi, challenges, expr, @@ -449,22 +450,18 @@ impl> ZKVMVerifier )); } // verify records (degree = 1) statement, thus no sumcheck - if cs - .r_expressions - .iter() - .chain(cs.w_expressions.iter()) - .chain(cs.lk_expressions.iter()) - .zip_eq( - proof.r_records_in_evals[..r_counts_per_instance] - .iter() - .chain(proof.w_records_in_evals[..w_counts_per_instance].iter()) - .chain(proof.lk_records_in_evals[..lk_counts_per_instance].iter()), + if izip!( + chain!(&cs.r_expressions, &cs.w_expressions, &cs.lk_expressions), + chain!( + &proof.r_records_in_evals[..r_counts_per_instance], + &proof.w_records_in_evals[..w_counts_per_instance], + &proof.lk_records_in_evals[..lk_counts_per_instance] ) - .any(|(expr, expected_evals)| { - eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) - != *expected_evals - }) - { + ) + .any(|(expr, expected_evals)| { + eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr) + != *expected_evals + }) { return Err(ZKVMError::VerifyError( "record evaluate != expected_evals".into(), )); @@ -472,7 +469,8 @@ impl> ZKVMVerifier // verify zero expression (degree = 1) statement, thus no sumcheck if cs.assert_zero_expressions.iter().any(|expr| { - eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO + eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr) + != E::ZERO }) { return Err(ZKVMError::VerifyError("zero expression != 0".into())); } @@ -516,42 +514,43 @@ impl> ZKVMVerifier .zip_eq(cs.w_table_expressions.iter()) .all(|(r, w)| r.table_spec.len == w.table_spec.len) ); - let is_skip_same_point_sumcheck = cs - .r_table_expressions - .iter() - .chain(cs.w_table_expressions.iter()) - .map(|rw| rw.table_spec.len) - .chain(cs.lk_table_expressions.iter().map(|lk| lk.table_len)) - .all_equal(); + // in table proof, we always skip same point sumcheck for now + // as tower sumcheck batch product argument/logup in same length + let is_skip_same_point_sumcheck = true; // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; - let expected_rounds = izip!( - // w_table_expression round match with r_table_expression so it fine to check either of them - &cs.r_table_expressions, - &proof.rw_hints_num_vars - ) - .flat_map(|(r, hint_num_vars)| match r.table_spec.addr_type { - // fixed address: get number of round from vk - SetTableAddrType::FixedAddr => { - let num_vars = ceil_log2(r.table_spec.len); + let expected_rounds = cs + .r_table_expressions + .iter() + .flat_map(|r| { + // iterate through structural witins and collect max round. + let num_vars = r.table_spec.len.map(ceil_log2).unwrap_or_else(|| { + r.table_spec + .structural_witins + .iter() + .map(|StructuralWitIn { id, max_len, .. }| { + let hint_num_vars = proof.rw_hints_num_vars[*id as usize]; + assert!((1 << hint_num_vars) <= *max_len); + hint_num_vars + }) + .max() + .unwrap() + }); [num_vars, num_vars] - } - // dynamic: respect prover hint - SetTableAddrType::DynamicAddr(_) => { - // check number of vars doesn't exceed max len defined in vk - // this is important to prevent address overlapping - assert!((1 << hint_num_vars) <= r.table_spec.len); - [*hint_num_vars, *hint_num_vars] - } - }) - .chain( - cs.lk_table_expressions - .iter() - .map(|l| ceil_log2(l.table_len)), - ) - .collect_vec(); + }) + .chain( + cs.lk_table_expressions + .iter() + .map(|l| ceil_log2(l.table_len)), + ) + .collect_vec(); + + for var in proof.rw_hints_num_vars.iter() { + transcript.append_message(&var.to_le_bytes()); + } + let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = TowerVerify::verify( @@ -689,6 +688,33 @@ impl> ZKVMVerifier [proof.rw_in_evals.to_vec(), proof.lk_in_evals.to_vec()].concat(), ) }; + + // evaluate structural witness from verifier succinctly + let structural_witnesses = cs + .r_table_expressions + .iter() + .flat_map(|set_table_expression| { + set_table_expression + .table_spec + .structural_witins + .iter() + .map( + |StructuralWitIn { + offset, + multi_factor, + .. + }| { + eval_wellform_address_vec( + *offset as u64, + *multi_factor as u64, + &input_opening_point, + ) + }, + ) + .collect_vec() + }) + .collect_vec(); + // verify records (degree = 1) statement, thus no sumcheck if interleave( &cs.r_table_expressions, // r @@ -705,6 +731,7 @@ impl> ZKVMVerifier eval_by_expr_with_instance( &proof.fixed_in_evals, &proof.wits_in_evals, + &structural_witnesses, pi, challenges, expr, @@ -715,26 +742,6 @@ impl> ZKVMVerifier )); } - // verify dynamic address evaluation succinctly - // TODO we can also skip their mpcs proof - for r_table in cs.r_table_expressions.iter() { - match &r_table.table_spec.addr_type { - SetTableAddrType::FixedAddr => (), - SetTableAddrType::DynamicAddr(spec) => { - let expected_eval = eval_wellform_address_vec( - spec.offset as u64, - WORD_SIZE as u64, - &input_opening_point, - ); - if expected_eval != proof.wits_in_evals[spec.addr_witin_id] { - return Err(ZKVMError::VerifyError( - "dynamic addr evaluate != expected_evals".into(), - )); - } - } - } - } - // assume public io is tiny vector, so we evaluate it directly without PCS for &Instance(idx) in cs.instance_name_map.keys() { let poly = raw_pi[idx].to_vec().into_mle(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index b006b10e1..8ae405ec9 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -309,15 +309,16 @@ impl ZKVMWitnesses { input: &TC::WitnessInput, ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_some()); - let cs = cs.get_cs(&TC::name()).unwrap(); let witness = TC::assign_instances( config, cs.num_witin as usize, + cs.num_structural_witin as usize, self.combined_lk_mlt.as_ref().unwrap(), input, )?; assert!(self.witnesses_tables.insert(TC::name(), witness).is_none()); + assert!(!self.witnesses_opcodes.contains_key(&TC::name())); Ok(()) diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index ef1a4fd68..8e736228d 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -33,6 +33,7 @@ pub trait TableCircuit { fn assign_instances( config: &Self::TableConfig, num_witin: usize, + num_structural_witin: usize, multiplicity: &[HashMap], input: &Self::WitnessInput, ) -> Result, ZKVMError>; diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index 852702ea7..45c3e123b 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -57,10 +57,11 @@ impl TableCircuit for OpsTableCircuit fn assign_instances( config: &Self::TableConfig, num_witin: usize, + num_structural_witin: usize, multiplicity: &[HashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[OP::ROM_TYPE as usize]; - config.assign_instances(num_witin, multiplicity, OP::len()) + config.assign_instances(num_witin, num_structural_witin, multiplicity, OP::len()) } } diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 3fe75d242..efe7c4de9 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -67,11 +67,15 @@ impl OpTableConfig { pub fn assign_instances( &self, num_witin: usize, + num_structural_witin: usize, multiplicity: &HashMap, length: usize, ) -> Result, ZKVMError> { - let mut witness = - RowMajorMatrix::::new(length, num_witin, InstancePaddingStrategy::Default); + let mut witness = RowMajorMatrix::::new( + length, + num_witin + num_structural_witin, + InstancePaddingStrategy::Default, + ); let mut mlts = vec![0; length]; for (idx, mlt) in multiplicity { diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 5695ee247..5a43af187 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -164,6 +164,7 @@ impl TableCircuit for ProgramTableCircuit { fn assign_instances( config: &Self::TableConfig, num_witin: usize, + num_structural_witin: usize, multiplicity: &[HashMap], program: &Program, ) -> Result, ZKVMError> { @@ -177,7 +178,7 @@ impl TableCircuit for ProgramTableCircuit { let mut witness = RowMajorMatrix::::new( config.program_size, - num_witin, + num_witin + num_structural_witin, InstancePaddingStrategy::Default, ); witness @@ -232,6 +233,7 @@ mod tests { let witness = ProgramTableCircuit::::assign_instances( &config, cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, &lkm, &program, ) diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 3080f68bf..ef4fff37c 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -93,11 +93,12 @@ impl TableCirc fn assign_instances( config: &Self::TableConfig, num_witin: usize, + num_structural_witin: usize, _multiplicity: &[HashMap], final_v: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - config.assign_instances(num_witin, final_v) + config.assign_instances(num_witin, num_structural_witin, final_v) } } @@ -140,11 +141,12 @@ impl TableCirc fn assign_instances( config: &Self::TableConfig, num_witin: usize, + num_structural_witin: usize, _multiplicity: &[HashMap], final_cycles: &[Cycle], ) -> Result, ZKVMError> { // assume returned table is well-formed including padding - config.assign_instances(num_witin, final_cycles) + config.assign_instances(num_witin, num_structural_witin, final_cycles) } } @@ -211,10 +213,11 @@ impl TableC fn assign_instances( config: &Self::TableConfig, num_witin: usize, + num_structural_witin: usize, _multiplicity: &[HashMap], final_v: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - config.assign_instances(num_witin, final_v) + config.assign_instances(num_witin, num_structural_witin, final_v) } } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 8a31633bc..7f6934cda 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -1,15 +1,15 @@ use std::{marker::PhantomData, sync::Arc}; -use ceno_emul::{Addr, Cycle}; +use ceno_emul::{Addr, Cycle, WORD_SIZE}; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use crate::{ - circuit_builder::{CircuitBuilder, DynamicAddr, SetTableAddrType, SetTableSpec}, + circuit_builder::{CircuitBuilder, SetTableSpec}, error::ZKVMError, - expression::{Expression, Fixed, ToExpr, WitIn}, + expression::{Expression, Fixed, StructuralWitIn, ToExpr, WitIn}, instructions::{ InstancePaddingStrategy, riscv::constants::{LIMB_BITS, LIMB_MASK}, @@ -82,8 +82,8 @@ impl NonVolatileTableConfig NonVolatileTableConfig NonVolatileTableConfig( &self, - num_witness: usize, + num_witin: usize, + num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result, ZKVMError> { let mut final_table = RowMajorMatrix::::new( NVRAM::len(&self.params), - num_witness, + num_witin + num_structural_witin, InstancePaddingStrategy::Default, ); @@ -226,8 +227,8 @@ impl PubIOTableConfig { || "init_table", NVRAM::RAM_TYPE, SetTableSpec { - addr_type: SetTableAddrType::FixedAddr, - len: NVRAM::len(&cb.params), + len: Some(NVRAM::len(&cb.params)), + structural_witins: vec![], }, init_table, )?; @@ -235,8 +236,8 @@ impl PubIOTableConfig { || "final_table", NVRAM::RAM_TYPE, SetTableSpec { - addr_type: SetTableAddrType::FixedAddr, - len: NVRAM::len(&cb.params), + len: Some(NVRAM::len(&cb.params)), + structural_witins: vec![], }, final_table, )?; @@ -277,12 +278,13 @@ impl PubIOTableConfig { /// TODO consider taking RowMajorMatrix as argument to save allocations. pub fn assign_instances( &self, - num_witness: usize, + num_witin: usize, + num_structural_witin: usize, final_cycles: &[Cycle], ) -> Result, ZKVMError> { let mut final_table = RowMajorMatrix::::new( NVRAM::len(&self.params), - num_witness, + num_witin + num_structural_witin, InstancePaddingStrategy::Default, ); @@ -302,7 +304,7 @@ impl PubIOTableConfig { /// dynamic address as witin, relied on augment of knowledge to prove address form #[derive(Clone, Debug)] pub struct DynVolatileRamTableConfig { - addr: WitIn, + addr: StructuralWitIn, final_v: Vec, final_cycle: WitIn, @@ -315,7 +317,13 @@ impl DynVolatileRamTableConfig pub fn construct_circuit( cb: &mut CircuitBuilder, ) -> Result { - let addr = cb.create_witin(|| "addr"); + let max_len = DVRAM::max_len(&cb.params); + let addr = cb.create_structural_witin( + || "addr", + max_len, + DVRAM::offset_addr(&cb.params), + WORD_SIZE, + ); let final_v = (0..DVRAM::V_LIMBS) .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) @@ -350,11 +358,8 @@ impl DynVolatileRamTableConfig || "init_table", DVRAM::RAM_TYPE, SetTableSpec { - addr_type: SetTableAddrType::DynamicAddr(DynamicAddr { - addr_witin_id: addr.id.into(), - offset: DVRAM::offset_addr(&cb.params), - }), - len: DVRAM::max_len(&cb.params), + len: None, + structural_witins: vec![addr], }, init_table, )?; @@ -362,11 +367,8 @@ impl DynVolatileRamTableConfig || "final_table", DVRAM::RAM_TYPE, SetTableSpec { - addr_type: SetTableAddrType::DynamicAddr(DynamicAddr { - addr_witin_id: addr.id.into(), - offset: DVRAM::offset_addr(&cb.params), - }), - len: DVRAM::max_len(&cb.params), + len: None, + structural_witins: vec![addr], }, final_table, )?; @@ -383,16 +385,21 @@ impl DynVolatileRamTableConfig /// TODO consider taking RowMajorMatrix as argument to save allocations. pub fn assign_instances( &self, - num_witness: usize, + num_witin: usize, + num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result, ZKVMError> { assert!(final_mem.len() <= DVRAM::max_len(&self.params)); assert!(DVRAM::max_len(&self.params).is_power_of_two()); + let offset_addr = StructuralWitIn { + id: self.addr.id + (num_witin as u16), + ..self.addr + }; + let params = self.params.clone(); - let addr_column = self.addr.id as u64; let padding_fn = move |row: u64, col: u64| { - if col == addr_column { + if col == offset_addr.id as u64 { DVRAM::addr(¶ms, row as usize) as u64 } else { 0u64 @@ -401,7 +408,7 @@ impl DynVolatileRamTableConfig let mut final_table = RowMajorMatrix::::new( final_mem.len(), - num_witness, + num_witin + num_structural_witin, InstancePaddingStrategy::Custom(Arc::new(padding_fn)), ); @@ -412,7 +419,6 @@ impl DynVolatileRamTableConfig .enumerate() .for_each(|(i, (row, rec))| { assert_eq!(rec.addr, DVRAM::addr(&self.params, i)); - set_val!(row, self.addr, rec.addr as u64); if self.final_v.len() == 1 { // Assign value directly. @@ -425,6 +431,8 @@ impl DynVolatileRamTableConfig }); } set_val!(row, self.final_cycle, rec.cycle); + + set_val!(row, offset_addr, rec.addr as u64); }); Ok(final_table) @@ -465,18 +473,23 @@ mod tests { value: 0, }) .collect_vec(); - let wit = - HintsCircuit::::assign_instances(&config, cb.cs.num_witin as usize, &lkm, &input) - .unwrap(); + let wit = HintsCircuit::::assign_instances( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &lkm, + &input, + ) + .unwrap(); let addr_column = cb .cs - .witin_namespace_map + .structural_witin_namespace_map .iter() .position(|name| name == "riscv/RAM_Memory_HintsTable/addr") .unwrap(); - let addr_padded_view = wit.column_padded(addr_column); + let addr_padded_view = wit.column_padded(addr_column + cb.cs.num_witin as usize); // Expect addresses to proceed consecutively inside the padding as well let expected = successors(Some(addr_padded_view[0]), |idx| { Some(*idx + F::from(WORD_SIZE as u64)) diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index d5fdf7363..83d8da017 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -50,10 +50,11 @@ impl TableCircuit for RangeTableCircuit fn assign_instances( config: &Self::TableConfig, num_witin: usize, + num_structural_witin: usize, multiplicity: &[HashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize]; - config.assign_instances(num_witin, multiplicity, RANGE::len()) + config.assign_instances(num_witin, num_structural_witin, multiplicity, RANGE::len()) } } diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 30937824c..6e7ebaee4 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -60,11 +60,15 @@ impl RangeTableConfig { pub fn assign_instances( &self, num_witin: usize, + num_structural_witin: usize, multiplicity: &HashMap, length: usize, ) -> Result, ZKVMError> { - let mut witness = - RowMajorMatrix::::new(length, num_witin, InstancePaddingStrategy::Default); + let mut witness = RowMajorMatrix::::new( + length, + num_witin + num_structural_witin, + InstancePaddingStrategy::Default, + ); let mut mlts = vec![0; length]; for (idx, mlt) in multiplicity { diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 3f742ed71..910bafa3c 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -485,13 +485,13 @@ mod tests { // verify let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); uint_c.expr().iter().zip(result).for_each(|(c, ret)| { - assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); + assert_eq!(eval_by_expr(&wit, &[], &challenges, c), E::from(ret)); }); // overflow if overflow { let carries = uint_c.carries.unwrap().last().unwrap().expr(); - assert_eq!(eval_by_expr(&wit, &challenges, &carries), E::ONE); + assert_eq!(eval_by_expr(&wit, &[], &challenges, &carries), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) @@ -660,13 +660,13 @@ mod tests { // verify let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); uint_c.expr().iter().zip(result).for_each(|(c, ret)| { - assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); + assert_eq!(eval_by_expr(&wit, &[], &challenges, c), E::from(ret)); }); // overflow if overflow { let overflow = uint_c.carries.unwrap().last().unwrap().expr(); - assert_eq!(eval_by_expr(&wit, &challenges, &overflow), E::ONE); + assert_eq!(eval_by_expr(&wit, &[], &challenges, &overflow), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index b3d32eab3..4a0bcbb51 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -96,6 +96,7 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { let monomial_terms = expr.evaluate( &|_| unreachable!(), &|witness_id| vec![(E::ONE, { vec![witness_id] })], + &|structural_witness_id, _, _, _| vec![(E::ONE, { vec![structural_witness_id] })], &|scalar| vec![(E::from(scalar), { vec![] })], &|challenge_id, pow, scalar, offset| { let challenge = challenges[challenge_id as usize];