Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing the air constraint #136

Merged
merged 9 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 55 additions & 1 deletion hints/src/hints.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proofman_starks_lib_c::{
acc_hint_field_c, acc_mul_hint_fields_c, get_hint_field_c, get_hint_ids_by_name_c, mul_hint_fields_c,
print_expression_c, print_row_c, set_hint_field_c, VecU64Result,
print_expression_c, print_row_c, set_hint_field_c, update_airgroupvalue_c, VecU64Result,
};

use std::collections::HashMap;
Expand Down Expand Up @@ -818,6 +818,60 @@ pub fn acc_mul_hint_fields<F: Field>(
(slice[0], slice[1])
}

#[allow(clippy::too_many_arguments)]
pub fn update_airgroupvalue<F: Field>(
setup_ctx: &SetupCtx,
proof_ctx: &ProofCtx<F>,
air_instance: &mut AirInstance<F>,
hint_id: usize,
hint_field_airgroupvalue: &str,
hint_field_name1: &str,
hint_field_name2: &str,
options1: HintFieldOptions,
options2: HintFieldOptions,
add: bool,
) -> u64 {
let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id);

let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void;
let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void;

let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void;
let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void;

let steps_params = StepsParams {
trace: air_instance.get_trace_ptr() as *mut c_void,
pols: air_instance.get_buffer_ptr() as *mut c_void,
public_inputs: public_inputs_ptr,
challenges: challenges_ptr,
airgroup_values: air_instance.airgroup_values.as_ptr() as *mut c_void,
airvalues: air_instance.airvalues.as_ptr() as *mut c_void,
evals: air_instance.evals.as_ptr() as *mut c_void,
xdivxsub: std::ptr::null_mut(),
p_const_pols: const_pols_ptr,
p_const_tree: const_tree_ptr,
custom_commits: air_instance.get_custom_commits_ptr(),
};

let raw_ptr = update_airgroupvalue_c(
(&setup.p_setup).into(),
(&steps_params).into(),
hint_id as u64,
hint_field_airgroupvalue,
hint_field_name1,
hint_field_name2,
(&options1).into(),
(&options2).into(),
add,
);

let hint_ids_result = unsafe { Box::from_raw(raw_ptr as *mut VecU64Result) };

let slice = unsafe { std::slice::from_raw_parts(hint_ids_result.values, hint_ids_result.n_values as usize) };

slice[0]
}

pub fn get_hint_field<F: Field>(
setup_ctx: &SetupCtx,
proof_ctx: &ProofCtx<F>,
Expand Down
22 changes: 12 additions & 10 deletions pil2-components/lib/std/pil/std_prod.pil
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,19 @@ private function piop_gprod_air() {
air_den *= (gprod_assumes_sel[i] * (gprod_assumes[i] + std_gamma - 1) + 1);
}

/*
At this point, the constraint has been transformed to:
gprod === ('gprod * (1 - L1) + L1) * air_num / air_den
// At this point, the constraint has been transformed to:
// gprod === ('gprod * (1 - L1) + L1) * air_num / air_den
// check that the constraint is satisfied
gprod * air_den === ('gprod * (1 - _L1) + _L1) * air_num;


Now, we whould add the direct terms to the constraint:
gprod === ('gprod * (1 - L1) + L1) * air_num / air_den * ∏ⱼ (sⱼ·(eⱼ+ɣ-1)+1) / (sⱼ'·(eⱼ'+ɣ-1)+1)
/*
At the very last row, it should be satisfied that:
gprod_result === gprod * ∏ⱼ (sⱼ·(eⱼ+ɣ-1)+1) / (sⱼ'·(eⱼ'+ɣ-1)+1)
where all sⱼ,sⱼ',eⱼ,eⱼ' are field elements, for all j.

We rewrite it as:
gprod * air_den * ∏ⱼ (sⱼ'·(eⱼ'+ɣ-1)+1) === ('gprod * (1 - L1) + L1) * air_num * ∏ⱼ (sⱼ·(eⱼ+ɣ-1)+1)
gprod_result * ∏ⱼ (sⱼ'·(eⱼ'+ɣ-1)+1) - gprod * ∏ⱼ (sⱼ·(eⱼ+ɣ-1)+1) === 0
*/

expr direct_num = 1;
Expand All @@ -234,10 +237,9 @@ private function piop_gprod_air() {
direct_den *= (direct_gprod_assumes_sel[i] * (direct_gprod_assumes[i] + std_gamma - 1) + 1);
}

@gprod_col{reference: gprod, numerator: air_num * direct_num, denominator: air_den * direct_den, result: gprod_result};

gprod * air_den * direct_den === ('gprod * (1 - _L1) + _L1) * air_num * direct_num;
_L1' * (gprod - gprod_result) === 0;
@gprod_col{reference: gprod, numerator_air: air_num, denominator_air: air_den,
numerator_direct: air_num, denominator_direct: air_den, result: gprod_result};
_L1' * (gprod_result * direct_den - gprod * direct_num) === 0;
}

// Note: We don't "update" the prod at the airgroup level (i.e., all the resulting prods generated by each air)
Expand Down
23 changes: 11 additions & 12 deletions pil2-components/lib/std/pil/std_sum.pil
Original file line number Diff line number Diff line change
Expand Up @@ -355,16 +355,18 @@ private function piop_gsum_air(const int blowupFactor = 2) {
isolated_num = gsum_s[isolated_term];
}

/*
At this point, the constraint has been transformed to:
gsum === 'gsum * (1 - L1) + ∑ᵢ imᵢ + num / den
// At this point, the constraint has been transformed to:
// gsum === 'gsum * (1 - L1) + ∑ᵢ imᵢ + num / den
// check that the constraint is satisfied
(gsum - 'gsum * (1 - __L1) - sum_ims) * isolated_den - isolated_num === 0;

Now, we whould add the direct terms to the constraint:
gsum === 'gsum * (1 - L1) + ∑ᵢ imᵢ + num / den + ∑ⱼ sⱼ / (eⱼ + ɣ)
/*
At the very last row, it should be satisfied that:
gsum_result === gsum + ∑ⱼ sⱼ / (eⱼ + ɣ)
where both sⱼ and eⱼ are field elements, for all j.

We rewrite it as:
((gsum - 'gsum * (1 - L1) - ∑ᵢ imᵢ)·den - num)·∏ⱼ (eⱼ + ɣ) - den·∑ⱼ sⱼ·∏ₖ≠ⱼ (eₖ + ɣ) === 0
(gsum_result - gsum∏ⱼ (eⱼ + ɣ) - ∑ⱼ sⱼ·∏ₖ≠ⱼ (eₖ + ɣ) === 0
*/

// Compute the direct terms numerator and denominator
Expand All @@ -380,12 +382,9 @@ private function piop_gsum_air(const int blowupFactor = 2) {
direct_num += _tmp;
}

expr hint_den = direct_den * isolated_den;
expr hint_num = sum_ims * hint_den + direct_num * isolated_den + isolated_num * direct_den;
@gsum_col{reference: gsum, numerator: hint_num, denominator: hint_den, result: gsum_result};

((gsum - 'gsum * (1 - __L1) - sum_ims) * isolated_den - isolated_num) * direct_den - isolated_den * direct_num === 0;
__L1' * (gsum - gsum_result) === 0;
@gsum_col{reference: gsum, numerator_air: sum_ims * isolated_den + isolated_num, denominator_air: isolated_den,
numerator_direct: direct_num, denominator_direct: direct_den, result: gsum_result};
__L1' * ((gsum_result - gsum) * direct_den - direct_num) === 0;
}

// Note: We don't "update" the sum at the airgroup level (i.e., all the resulting sums generated by each air)
Expand Down
25 changes: 19 additions & 6 deletions pil2-components/lib/std/rs/src/std_prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use p3_field::PrimeField;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::{AirInstance, ExecutionCtx, ModeName, ProofCtx, SetupCtx, StdMode};
use proofman_hints::{
acc_mul_hint_fields, get_hint_field, get_hint_field_a, get_hint_field_constant, get_hint_field_constant_a,
get_hint_ids_by_name, HintFieldOptions, HintFieldOutput, HintFieldValue, HintFieldValuesVec,
get_hint_field, get_hint_field_a, get_hint_field_constant, get_hint_field_constant_a, get_hint_ids_by_name,
update_airgroupvalue, acc_mul_hint_fields, HintFieldOptions, HintFieldOutput, HintFieldValue, HintFieldValuesVec,
};

use crate::{print_debug_info, update_debug_data, DebugData, Decider};
Expand Down Expand Up @@ -294,22 +294,35 @@ impl<F: PrimeField> WitnessComponent<F> for StdProd<F> {

// This call calculates "numerator" / "denominator" and accumulates it into "reference". Its last value is stored into "result"
// Alternatively, this could be done using get_hint_field and set_hint_field methods and calculating the operations in Rust,
// TODO: GENERALIZE CALLS
let (pol_id, airgroupvalue_id) = acc_mul_hint_fields::<F>(
let (pol_id, _) = acc_mul_hint_fields::<F>(
&sctx,
&pctx,
air_instance,
gprod_hint,
"reference",
"result",
"numerator",
"denominator",
"numerator_air",
"denominator_air",
HintFieldOptions::default(),
HintFieldOptions::inverse(),
false,
);

air_instance.set_commit_calculated(pol_id as usize);

let airgroupvalue_id = update_airgroupvalue::<F>(
&sctx,
&pctx,
air_instance,
gprod_hint,
"result",
"numerator_direct",
"denominator_direct",
HintFieldOptions::default(),
HintFieldOptions::inverse(),
false,
);

air_instance.set_airgroupvalue_calculated(airgroupvalue_id as usize);
}
}
Expand Down
26 changes: 20 additions & 6 deletions pil2-components/lib/std/rs/src/std_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ use p3_field::PrimeField;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx, StdMode, ModeName};
use proofman_hints::{
acc_mul_hint_fields, get_hint_field, get_hint_field_a, get_hint_field_constant, get_hint_field_constant_a,
get_hint_ids_by_name, mul_hint_fields, HintFieldOptions, HintFieldOutput, HintFieldValue, HintFieldValuesVec,
get_hint_field, get_hint_field_a, get_hint_field_constant, get_hint_field_constant_a, acc_mul_hint_fields,
update_airgroupvalue, get_hint_ids_by_name, mul_hint_fields, HintFieldOptions, HintFieldOutput, HintFieldValue,
HintFieldValuesVec,
};

use crate::{print_debug_info, update_debug_data, DebugData, Decider};
Expand Down Expand Up @@ -312,22 +313,35 @@ impl<F: PrimeField> WitnessComponent<F> for StdSum<F> {

// This call accumulates "expression" into "reference" expression and stores its last value to "result"
// Alternatively, this could be done using get_hint_field and set_hint_field methods and doing the accumulation in Rust,
// TODO: GENERALIZE CALLS
let (pol_id, airgroupvalue_id) = acc_mul_hint_fields::<F>(
let (pol_id, _) = acc_mul_hint_fields::<F>(
&sctx,
&pctx,
air_instance,
gsum_hint,
"reference",
"result",
"numerator",
"denominator",
"numerator_air",
"denominator_air",
HintFieldOptions::default(),
HintFieldOptions::inverse(),
true,
);

air_instance.set_commit_calculated(pol_id as usize);

let airgroupvalue_id = update_airgroupvalue::<F>(
&sctx,
&pctx,
air_instance,
gsum_hint,
"result",
"numerator_direct",
"denominator_direct",
HintFieldOptions::default(),
HintFieldOptions::inverse(),
true,
);

air_instance.set_airgroupvalue_calculated(airgroupvalue_id as usize);
}
}
Expand Down
3 changes: 2 additions & 1 deletion pil2-stark/lib/include/starks_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
void fri_proof_get_tree_root(void *pFriProof, void* root, uint64_t tree_index);
void fri_proof_set_airgroupvalues(void *pFriProof, void *airgroupValues);
void fri_proof_set_airvalues(void *pFriProof, void *airValues);
void *fri_proof_get_zkinproof(uint64_t proof_id, void *pFriProof, void* pPublics, void* pChallenges, void *pStarkInfo, char* globalInfoFile, char *fileDir);
void *fri_proof_get_zkinproof(void *pFriProof, void* pPublics, void* pChallenges, void *pStarkInfo, char* proof_name, char* globalInfoFile, char *fileDir);
void fri_proof_free_zkinproof(void *pZkinProof);
void fri_proof_free(void *pFriProof);

Expand Down Expand Up @@ -55,6 +55,7 @@
uint64_t mul_hint_fields(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameDest, char *hintFieldName1, char *hintFieldName2, void* hintOptions1, void *hintOptions2);
void *acc_hint_field(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameDest, char *hintFieldNameAirgroupVal, char *hintFieldName, bool add);
void *acc_mul_hint_fields(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameDest, char *hintFieldNameAirgroupVal, char *hintFieldName1, char *hintFieldName2, void* hintOptions1, void *hintOptions2, bool add);
void *update_airgroupvalue(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameAirgroupVal, char *hintFieldName1, char *hintFieldName2, void* hintOptions1, void *hintOptions2, bool add);
uint64_t set_hint_field(void *pSetupCtx, void* stepsParams, void *values, uint64_t hintId, char* hintFieldName);

// Starks
Expand Down
10 changes: 7 additions & 3 deletions pil2-stark/src/api/starks_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ void fri_proof_set_airvalues(void *pFriProof, void *airValues)
FRIProof<Goldilocks::Element> *friProof = (FRIProof<Goldilocks::Element> *)pFriProof;
friProof->proof.setAirValues((Goldilocks::Element *)airValues);
}
void *fri_proof_get_zkinproof(uint64_t proof_id, void *pFriProof, void* pPublics, void* pChallenges, void *pStarkInfo, char* globalInfoFile, char *fileDir)
void *fri_proof_get_zkinproof(void *pFriProof, void* pPublics, void* pChallenges, void *pStarkInfo, char* proof_name, char* globalInfoFile, char *fileDir)
{
json globalInfo;
file2json(globalInfoFile, globalInfo);
Expand Down Expand Up @@ -119,8 +119,8 @@ void *fri_proof_get_zkinproof(uint64_t proof_id, void *pFriProof, void* pPublics
if (!std::filesystem::exists(string(fileDir) + "/proofs")) {
std::filesystem::create_directory(string(fileDir) + "/proofs");
}
json2file(jProof, string(fileDir) + "/proofs/proof_" + to_string(proof_id) + ".json");
json2file(zkin, string(fileDir) + "/zkin/proof_" + to_string(proof_id) + "_zkin.json");
json2file(jProof, string(fileDir) + "/proofs/proof_" + proof_name + ".json");
json2file(zkin, string(fileDir) + "/zkin/proof_" + proof_name + "_zkin.json");
}

return (void *) new nlohmann::json(zkin);
Expand Down Expand Up @@ -279,6 +279,10 @@ void *acc_mul_hint_fields(void *pSetupCtx, void* stepsParams, uint64_t hintId, c
return new VecU64Result(accMulHintFields(*(SetupCtx *)pSetupCtx, *(StepsParams *)stepsParams, hintId, string(hintFieldNameDest), string(hintFieldNameAirgroupVal), string(hintFieldName1), string(hintFieldName2),*(HintFieldOptions *)hintOptions1, *(HintFieldOptions *)hintOptions2, add));
}

void *update_airgroupvalue(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameAirgroupVal, char *hintFieldName1, char *hintFieldName2, void* hintOptions1, void *hintOptions2, bool add) {
return new VecU64Result(updateAirgroupValue(*(SetupCtx *)pSetupCtx, *(StepsParams *)stepsParams, hintId, string(hintFieldNameAirgroupVal), string(hintFieldName1), string(hintFieldName2),*(HintFieldOptions *)hintOptions1, *(HintFieldOptions *)hintOptions2, add));
}


uint64_t set_hint_field(void *pSetupCtx, void* params, void *values, uint64_t hintId, char * hintFieldName)
{
Expand Down
3 changes: 2 additions & 1 deletion pil2-stark/src/api/starks_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
void fri_proof_get_tree_root(void *pFriProof, void* root, uint64_t tree_index);
void fri_proof_set_airgroupvalues(void *pFriProof, void *airgroupValues);
void fri_proof_set_airvalues(void *pFriProof, void *airValues);
void *fri_proof_get_zkinproof(uint64_t proof_id, void *pFriProof, void* pPublics, void* pChallenges, void *pStarkInfo, char* globalInfoFile, char *fileDir);
void *fri_proof_get_zkinproof(void *pFriProof, void* pPublics, void* pChallenges, void *pStarkInfo, char* proof_name, char* globalInfoFile, char *fileDir);
void fri_proof_free_zkinproof(void *pZkinProof);
void fri_proof_free(void *pFriProof);

Expand Down Expand Up @@ -55,6 +55,7 @@
uint64_t mul_hint_fields(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameDest, char *hintFieldName1, char *hintFieldName2, void* hintOptions1, void *hintOptions2);
void *acc_hint_field(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameDest, char *hintFieldNameAirgroupVal, char *hintFieldName, bool add);
void *acc_mul_hint_fields(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameDest, char *hintFieldNameAirgroupVal, char *hintFieldName1, char *hintFieldName2, void* hintOptions1, void *hintOptions2, bool add);
void *update_airgroupvalue(void *pSetupCtx, void* stepsParams, uint64_t hintId, char *hintFieldNameAirgroupVal, char *hintFieldName1, char *hintFieldName2, void* hintOptions1, void *hintOptions2, bool add);
uint64_t set_hint_field(void *pSetupCtx, void* stepsParams, void *values, uint64_t hintId, char* hintFieldName);

// Starks
Expand Down
2 changes: 1 addition & 1 deletion pil2-stark/src/starkpil/expressions_avx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class ExpressionsAvx : public ExpressionsCtx {
copyPolynomial(&destVals[j][k*FIELD_EXTENSION], dests[j].params[k].inverse, dests[j].params[k].dim, &bufferT_[nColsStagesAcc[buffPos] + stagePos]);
continue;
} else if(dests[j].params[k].op == opType::number) {
uint64_t val = dests[j].params[k].inverse ? Goldilocks::inv(Goldilocks::fromU64(dests[j].params[k].value)).fe : dests[j].params[k].value;
uint64_t val = dests[j].params[k].value;
destVals[j][k*FIELD_EXTENSION] = _mm256_set1_epi64x(val);
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion pil2-stark/src/starkpil/expressions_avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class ExpressionsAvx512 : public ExpressionsCtx {
copyPolynomial(&destVals[j][k*FIELD_EXTENSION], dests[j].params[k].inverse, dests[j].params[k].dim, &bufferT_[nColsStagesAcc[buffPos] + stagePos]);
continue;
} else if(dests[j].params[k].op == opType::number) {
uint64_t val = dests[j].params[k].inverse ? Goldilocks::inv(Goldilocks::fromU64(dests[j].params[k].value)).fe : dests[j].params[k].value;
uint64_t val = dests[j].params[k].value;
destVals[j][k*FIELD_EXTENSION] = _mm512_set1_epi64(val);
continue;
}
Expand Down
Loading
Loading