Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseFold: all open functions accept Arc instead of DenseMultilinearExtension #563

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a1f1b9b
Refactor pcs tests by adding some utility functions.
yczhangsjtu Nov 5, 2024
eeec18d
Fix compilation errors.
yczhangsjtu Nov 5, 2024
8206cbc
Temporarily avoid a bug that will be fixed by other changes.
yczhangsjtu Nov 5, 2024
c7cdc59
Merge remote-tracking branch 'origin/master' into feat/basefold-refac…
yczhangsjtu Nov 5, 2024
ce8393f
Remove unnecessary cfg(test)
yczhangsjtu Nov 5, 2024
d925cbf
Remove an extra function.
yczhangsjtu Nov 5, 2024
c4d69ff
Cleanup use codes.
yczhangsjtu Nov 5, 2024
a71b2f5
Merge remote-tracking branch 'origin/master' into feat/basefold-refac…
yczhangsjtu Nov 5, 2024
d01ebda
Merge the two similar benchmarks.
yczhangsjtu Nov 5, 2024
dc6b599
Merge two benchmarks sharing similar codes.
yczhangsjtu Nov 5, 2024
7dfda12
Gate more test functions.
yczhangsjtu Nov 5, 2024
a6c6181
Refactor tests in basefold.rs.
yczhangsjtu Nov 6, 2024
3492d71
Merge branch 'feat/basefold-refactor-extract-0' into feat/basefold-re…
yczhangsjtu Nov 6, 2024
44c5c69
Use Arc in open APIs.
yczhangsjtu Nov 6, 2024
23b10df
Add a missing utility function.
yczhangsjtu Nov 6, 2024
e644740
Remove a redundant function that should be added in another PR.
yczhangsjtu Nov 6, 2024
b5924bf
Fix clippy.
yczhangsjtu Nov 6, 2024
d42f794
Fix clippy.
yczhangsjtu Nov 6, 2024
8d82732
Cargo fmt.
yczhangsjtu Nov 6, 2024
4883bf7
Fix Cargo.toml.
yczhangsjtu Nov 6, 2024
56fb6fe
Merge remote-tracking branch 'origin/master' into feat/basefold-refac…
yczhangsjtu Nov 28, 2024
1e80c7e
Update basefold.rs
yczhangsjtu Nov 28, 2024
2d60df4
Merge remote-tracking branch 'origin/master' into feat/basefold-refac…
yczhangsjtu Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mpcs/benches/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn bench_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>(
let eval = poly.evaluate(point.as_slice());
transcript.append_field_element_ext(&eval);
let transcript_for_bench = transcript.clone();
let poly = ArcMultilinearExtension::from(poly);
let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap();

group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| {
Expand Down Expand Up @@ -164,6 +165,10 @@ fn bench_batch_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>
let values: Vec<E> = evals.iter().map(Evaluation::value).copied().collect();
transcript.append_field_element_exts(values.as_slice());
let transcript_for_bench = transcript.clone();
let polys = polys
.iter()
.map(|poly| ArcMultilinearExtension::from(poly.clone()))
.collect::<Vec<_>>();
let proof =
Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap();

Expand Down Expand Up @@ -196,11 +201,7 @@ fn bench_batch_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>
&mut transcript,
);

let values: Vec<E> = evals
.iter()
.map(Evaluation::value)
.copied()
.collect::<Vec<E>>();
let values: Vec<E> = evals.iter().map(Evaluation::value).copied().collect();
transcript.append_field_element_exts(values.as_slice());

let backup_transcript = transcript.clone();
Expand Down
60 changes: 34 additions & 26 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
inner_product, inner_product_three, interpolate_field_type_over_boolean_hypercube,
},
expression::{Expression, Query, Rotation},
ext_to_usize,
ext_to_usize, field_type_to_ext_vec,
hash::{Digest, write_digest_to_transcript},
log2_strict,
merkle_tree::MerkleTree,
Expand All @@ -34,7 +34,6 @@ use query_phase::{
prover_query_phase, simple_batch_prover_query_phase, simple_batch_verifier_query_phase,
verifier_query_phase,
};
use std::{borrow::BorrowMut, ops::Deref};
pub use structure::BasefoldSpec;
use structure::{BasefoldProof, ProofQueriesResultWithMerklePath};
use transcript::Transcript;
Expand All @@ -51,7 +50,6 @@ use rayon::{
iter::IntoParallelIterator,
prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator},
};
use std::borrow::Cow;
pub use sumcheck::{one_level_eval_hc, one_level_interp_hc};

type SumCheck<F> = ClassicSumCheck<CoefficientsProver<F>>;
Expand Down Expand Up @@ -466,7 +464,7 @@ where
/// will panic.
fn open(
pp: &Self::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Self::CommitmentWithData,
point: &[E],
_eval: &E, // Opening does not need eval, except for sanity check
Expand All @@ -480,7 +478,7 @@ where
// the protocol won't work, and saves no verifier work anyway.
// In this case, simply return the evaluations as trivial proof.
if comm.is_trivial::<Spec>() {
return Ok(Self::Proof::trivial(vec![poly.evaluations.clone()]));
return Ok(Self::Proof::trivial(vec![poly.evaluations().clone()]));
}

assert!(comm.num_vars >= Spec::get_basecode_msg_size_log());
Expand All @@ -499,8 +497,8 @@ where
point,
comm,
transcript,
poly.num_vars,
poly.num_vars - Spec::get_basecode_msg_size_log(),
poly.num_vars(),
poly.num_vars() - Spec::get_basecode_msg_size_log(),
);

// 2. Query phase. ---------------------------------------
Expand Down Expand Up @@ -546,15 +544,15 @@ where
/// not very useful in ceno.
fn batch_open(
pp: &Self::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Self::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
transcript: &mut Transcript<E>,
) -> Result<Self::Proof, Error> {
let timer = start_timer!(|| "Basefold::batch_open");
let num_vars = polys.iter().map(|poly| poly.num_vars).max().unwrap();
let min_num_vars = polys.iter().map(|p| p.num_vars).min().unwrap();
let num_vars = polys.iter().map(|poly| poly.num_vars()).max().unwrap();
let min_num_vars = polys.iter().map(|p| p.num_vars()).min().unwrap();
assert!(min_num_vars >= Spec::get_basecode_msg_size_log());

comms.iter().for_each(|comm| {
Expand Down Expand Up @@ -603,28 +601,31 @@ where
let merged_polys = evals.iter().zip(poly_iter_ext(&eq_xt)).fold(
// This folding will generate a vector of |points| pairs of (scalar, polynomial)
// The polynomials are initialized to zero, and the scalars are initialized to one
vec![(E::ONE, Cow::<DenseMultilinearExtension<E>>::default()); points.len()],
vec![(E::ONE, Vec::<E>::new()); points.len()],
|mut merged_polys, (eval, eq_xt_i)| {
// For each polynomial to open, eval.point() specifies which point it is to be opened at.
if merged_polys[eval.point()].1.num_vars == 0 {
if merged_polys[eval.point()].1.is_empty() {
// If the accumulator for this point is still the zero polynomial,
// directly assign the random coefficient and the polynomial to open to
// this accumulator
merged_polys[eval.point()] = (eq_xt_i, Cow::Borrowed(&polys[eval.poly()]));
merged_polys[eval.point()] = (
eq_xt_i,
field_type_to_ext_vec(polys[eval.poly()].evaluations()),
);
} else {
// If the accumulator is unempty now, first force its scalar to 1, i.e.,
// make (scalar, polynomial) to (1, scalar * polynomial)
let coeff = merged_polys[eval.point()].0;
if coeff != E::ONE {
merged_polys[eval.point()].0 = E::ONE;
multiply_poly(merged_polys[eval.point()].1.to_mut().borrow_mut(), &coeff);
multiply_poly(&mut merged_polys[eval.point()].1, &coeff);
}
// Equivalent to merged_poly += poly * batch_coeff. Note that
// add_assign_mixed_with_coeff allows adding two polynomials with
// different variables, and the result has the same number of vars
// with the larger one of the two added polynomials.
add_polynomial_with_coeff(
merged_polys[eval.point()].1.to_mut().borrow_mut(),
&mut merged_polys[eval.point()].1,
&polys[eval.poly()],
&eq_xt_i,
);
Expand All @@ -642,18 +643,16 @@ where
.iter()
.zip(&points)
.map(|((scalar, poly), point)| {
inner_product(
&poly_iter_ext(poly).collect_vec(),
build_eq_x_r_vec(point).iter(),
) * scalar
* E::from(1 << (num_vars - poly.num_vars))
inner_product(poly, build_eq_x_r_vec(point).iter())
* scalar
* E::from(1 << (num_vars - log2_strict(poly.len())))
// When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube
})
.sum::<E>();
assert_eq!(expected_sum, target_sum);

merged_polys.iter().enumerate().for_each(|(i, (_, poly))| {
assert_eq!(points[i].len(), poly.num_vars);
assert_eq!(points[i].len(), log2_strict(poly.len()));
});
}

Expand All @@ -666,12 +665,17 @@ where
* scalar
})
.sum();
let sumcheck_polys: Vec<&DenseMultilinearExtension<E>> = merged_polys
let sumcheck_polys: Vec<DenseMultilinearExtension<E>> = merged_polys
.iter()
.map(|(_, poly)| poly.deref())
.map(|(_, poly)| {
DenseMultilinearExtension::from_evaluations_ext_vec(
log2_strict(poly.len()),
poly.clone(),
)
})
.collect_vec();
let virtual_poly =
VirtualPolynomial::new(&expression, sumcheck_polys, &[], points.as_slice());
VirtualPolynomial::new(&expression, sumcheck_polys.iter(), &[], points.as_slice());

let (challenges, merged_poly_evals, sumcheck_proof) =
SumCheck::prove(&(), num_vars, virtual_poly, target_sum, transcript)?;
Expand All @@ -695,7 +699,7 @@ where
if cfg!(feature = "sanity-check") {
let poly_evals = polys
.iter()
.map(|poly| poly.evaluate(&challenges[..poly.num_vars]))
.map(|poly| poly.evaluate(&challenges[..poly.num_vars()]))
.collect_vec();
let new_target_sum = inner_product(&poly_evals, &coeffs);
let desired_sum = merged_polys
Expand All @@ -705,7 +709,11 @@ where
.map(|(((scalar, poly), point), evals_from_sum_check)| {
assert_eq!(
evals_from_sum_check,
poly.evaluate(&challenges[..poly.num_vars])
DenseMultilinearExtension::from_evaluations_ext_vec(
log2_strict(poly.len()),
poly.clone()
)
.evaluate(&challenges[..log2_strict(poly.len())])
);
*scalar
* evals_from_sum_check
Expand Down
24 changes: 15 additions & 9 deletions mpcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub fn pcs_batch_commit_and_write<E: ExtensionField, Pcs: PolynomialCommitmentSc

pub fn pcs_open<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
pp: &Pcs::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Pcs::CommitmentWithData,
point: &[E],
eval: &E,
Expand All @@ -73,7 +73,7 @@ pub fn pcs_open<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(

pub fn pcs_batch_open<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
pp: &Pcs::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Pcs::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
Expand Down Expand Up @@ -162,7 +162,7 @@ pub trait PolynomialCommitmentScheme<E: ExtensionField>: Clone + Debug {

fn open(
pp: &Self::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Self::CommitmentWithData,
point: &[E],
eval: &E,
Expand All @@ -171,7 +171,7 @@ pub trait PolynomialCommitmentScheme<E: ExtensionField>: Clone + Debug {

fn batch_open(
pp: &Self::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Self::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
Expand Down Expand Up @@ -226,7 +226,7 @@ where
{
fn ni_open(
pp: &Self::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Self::CommitmentWithData,
point: &[E],
eval: &E,
Expand All @@ -237,7 +237,7 @@ where

fn ni_batch_open(
pp: &Self::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Self::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
Expand Down Expand Up @@ -323,17 +323,17 @@ use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
fn validate_input<E: ExtensionField>(
function: &str,
param_num_vars: usize,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
points: &[Vec<E>],
) -> Result<(), Error> {
let polys = polys.iter().collect_vec();
let points = points.iter().collect_vec();
for poly in polys.iter() {
if param_num_vars < poly.num_vars {
if param_num_vars < poly.num_vars() {
return Err(err_too_many_variates(
function,
param_num_vars,
poly.num_vars,
poly.num_vars(),
));
}
}
Expand Down Expand Up @@ -458,6 +458,7 @@ pub mod test_util {
let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap();
let point = get_point_from_challenge(num_vars, &mut transcript);
let eval = poly.evaluate(point.as_slice());
let poly = ArcMultilinearExtension::from(poly);
transcript.append_field_element_ext(&eval);

(
Expand Down Expand Up @@ -529,6 +530,11 @@ pub mod test_util {
.collect::<Vec<E>>();
transcript.append_field_element_exts(values.as_slice());

let polys = polys
.iter()
.map(|poly| ArcMultilinearExtension::from(poly.clone()))
.collect_vec();

let proof =
Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap();
(comms, evals, proof, transcript.read_challenge())
Expand Down
Loading