Skip to content

Commit

Permalink
fixing the product bus
Browse files Browse the repository at this point in the history
  • Loading branch information
hecmas committed Dec 16, 2024
1 parent a03b9ac commit f9621eb
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 57 deletions.
2 changes: 1 addition & 1 deletion pil2-components/lib/std/pil/std_prod.pil
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ private function piop_gprod_air() {
}

@gprod_col{reference: gprod, numerator_air: air_num, denominator_air: air_den,
numerator_direct: air_num, denominator_direct: air_den, result: gprod_result};
numerator_direct: direct_num, denominator_direct: direct_den, result: gprod_result};
LLAST * (gprod_result * direct_den - gprod * direct_num) === 0;

// Store the airgroup and air ids for global hints
Expand Down
4 changes: 2 additions & 2 deletions pil2-components/lib/std/rs/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use proofman_common::SetupCtx;
use proofman_hints::{get_hint_field_constant, HintFieldOptions, HintFieldOutput, HintFieldValue};

// Helper to extract hint fields
pub fn get_global_hint_field_as<T, F>(sctx: Arc<SetupCtx>, hint_id: u64, field_name: &str) -> T
pub fn get_global_hint_field_constant_as<T, F>(sctx: Arc<SetupCtx>, hint_id: u64, field_name: &str) -> T
where
T: TryFrom<u64>,
T::Error: std::fmt::Debug,
Expand All @@ -27,7 +27,7 @@ where
.expect(&format!("Cannot convert value to {}", std::any::type_name::<T>()))
}

pub fn get_hint_field_as_field<F: PrimeField>(
pub fn get_hint_field_constant_as_field<F: PrimeField>(
sctx: &SetupCtx,
airgroup_id: usize,
air_id: usize,
Expand Down
61 changes: 17 additions & 44 deletions pil2-components/lib/std/rs/src/std_prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@ use std::{
sync::{Arc, Mutex},
};

use num_traits::ToPrimitive;
use p3_field::PrimeField;

use proofman::{get_hint_field_gc_constant_a, WitnessComponent, WitnessManager};
use proofman_common::{AirInstance, ExecutionCtx, ModeName, ProofCtx, SetupCtx, StdMode};
use proofman_hints::{
get_hint_field, get_hint_field_a, get_hint_field_constant, get_hint_field_constant_a, acc_mul_hint_fields,
update_airgroupvalue, get_hint_ids_by_name, HintFieldOptions, HintFieldOutput, HintFieldValue, HintFieldValuesVec,
update_airgroupvalue, get_hint_ids_by_name, HintFieldOptions, HintFieldValue, HintFieldValuesVec,
};

use crate::{
print_debug_info, update_debug_data, DebugData, get_global_hint_field_as, get_hint_field_as_field,
print_debug_info, update_debug_data, DebugData, get_global_hint_field_constant_as, get_hint_field_constant_as_field,
get_row_field_value, extract_field_element_as_usize,
};

Expand All @@ -40,9 +39,9 @@ impl<F: PrimeField> StdProd<F> {
true => None,
false => {
// Get the "stage_wc" hint
let stage_wc = get_global_hint_field_as::<u32, F>(sctx.clone(), std_prod_users_id[0], "stage_wc");
let stage_wc = get_global_hint_field_constant_as::<u32, F>(sctx.clone(), std_prod_users_id[0], "stage_wc");
Some(Mutex::new(stage_wc))
},
}
},
debug_data: if mode.name == ModeName::Debug { Some(Mutex::new(HashMap::new())) } else { None },
});
Expand Down Expand Up @@ -85,10 +84,16 @@ impl<F: PrimeField> StdProd<F> {
HintFieldOptions::default(),
);

let busid =
get_hint_field::<F>(sctx, pctx, air_instance, hint as usize, "busid", HintFieldOptions::default());
let opid = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
hint as usize,
"busid",
HintFieldOptions::default(),
);

let is_global = get_hint_field_as_field::<F>(
let is_global = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
Expand All @@ -97,7 +102,7 @@ impl<F: PrimeField> StdProd<F> {
HintFieldOptions::default(),
);

let proves = get_hint_field_as_field::<F>(
let proves = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
Expand Down Expand Up @@ -126,7 +131,7 @@ impl<F: PrimeField> StdProd<F> {
HintFieldOptions::default(),
);

let deg_expr = get_hint_field_as_field::<F>(
let deg_expr = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
Expand All @@ -135,7 +140,7 @@ impl<F: PrimeField> StdProd<F> {
HintFieldOptions::default(),
);

let deg_sel = get_hint_field_as_field::<F>(
let deg_sel = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
Expand All @@ -146,20 +151,6 @@ impl<F: PrimeField> StdProd<F> {

// If both the expresion and the mul are of degree zero, then simply update the bus once
if deg_expr.is_zero() && deg_sel.is_zero() {
// In this case, the busid must be a field element
let opid = match busid {
HintFieldValue::Field(opid) => {
// If opids are specified, then only update the bus if the opid is in the list
if let Some(opids) = &self.mode.opids {
if !opids.contains(&opid.as_canonical_biguint().to_u64().expect("Cannot convert to u64")) {
continue;
}
}
opid
}
_ => panic!("busid must be a field element"),
};

update_bus(
airgroup_id,
air_id,
Expand All @@ -176,22 +167,6 @@ impl<F: PrimeField> StdProd<F> {
// Otherwise, update the bus for each row
else {
for j in 0..num_rows {
// Get the opid for this row
let opid = match busid.get(j) {
HintFieldOutput::Field(opid) => {
// If opids are specified, then only update the bus if the opid is in the list
if let Some(opids) = &self.mode.opids {
if !opids
.contains(&opid.as_canonical_biguint().to_u64().expect("Cannot convert to u64"))
{
continue;
}
}
opid
}
_ => panic!("busid must be a field element"),
};

update_bus(
airgroup_id,
air_id,
Expand Down Expand Up @@ -260,7 +235,7 @@ impl<F: PrimeField> WitnessComponent<F> for StdProd<F> {
// Get the number of product check users and their airgroup and air IDs
let std_prod_users = get_hint_ids_by_name(sctx.get_global_bin(), "std_prod_users")[0];

let num_users = get_global_hint_field_as::<usize, F>(sctx.clone(), std_prod_users, "num_users");
let num_users = get_global_hint_field_constant_as::<usize, F>(sctx.clone(), std_prod_users, "num_users");
let airgroup_ids = get_hint_field_gc_constant_a::<F>(sctx.clone(), std_prod_users, "airgroup_ids", false);
let air_ids = get_hint_field_gc_constant_a::<F>(sctx.clone(), std_prod_users, "air_ids", false);

Expand Down Expand Up @@ -316,7 +291,6 @@ impl<F: PrimeField> WitnessComponent<F> for StdProd<F> {
HintFieldOptions::inverse(),
false,
);

air_instance.set_commit_calculated(pol_id as usize);

let airgroupvalue_id = update_airgroupvalue::<F>(
Expand All @@ -331,7 +305,6 @@ impl<F: PrimeField> WitnessComponent<F> for StdProd<F> {
HintFieldOptions::inverse(),
false,
);

air_instance.set_airgroupvalue_calculated(airgroupvalue_id as usize);
}
}
Expand Down
14 changes: 7 additions & 7 deletions pil2-components/lib/std/rs/src/std_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use proofman_hints::{
};

use crate::{
print_debug_info, update_debug_data, DebugData, get_global_hint_field_as, get_hint_field_as_field,
print_debug_info, update_debug_data, DebugData, get_global_hint_field_constant_as, get_hint_field_constant_as_field,
get_row_field_value, extract_field_element_as_usize,
};

Expand All @@ -41,9 +41,9 @@ impl<F: PrimeField> StdSum<F> {
true => None,
false => {
// Get the "stage_wc" hint
let stage_wc = get_global_hint_field_as::<u32, F>(sctx.clone(), std_sum_users_id[0], "stage_wc");
let stage_wc = get_global_hint_field_constant_as::<u32, F>(sctx.clone(), std_sum_users_id[0], "stage_wc");
Some(Mutex::new(stage_wc))
},
}
},
debug_data: if mode.name == ModeName::Debug { Some(Mutex::new(HashMap::new())) } else { None },
});
Expand Down Expand Up @@ -89,7 +89,7 @@ impl<F: PrimeField> StdSum<F> {
let busid =
get_hint_field::<F>(sctx, pctx, air_instance, hint as usize, "busid", HintFieldOptions::default());

let is_global = get_hint_field_as_field::<F>(
let is_global = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
Expand All @@ -113,7 +113,7 @@ impl<F: PrimeField> StdSum<F> {
HintFieldOptions::default(),
);

let deg_expr = get_hint_field_as_field::<F>(
let deg_expr = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
Expand All @@ -122,7 +122,7 @@ impl<F: PrimeField> StdSum<F> {
HintFieldOptions::default(),
);

let deg_mul = get_hint_field_as_field::<F>(
let deg_mul = get_hint_field_constant_as_field::<F>(
sctx,
airgroup_id,
air_id,
Expand Down Expand Up @@ -259,7 +259,7 @@ impl<F: PrimeField> WitnessComponent<F> for StdSum<F> {
// Get the number of sum check users and their airgroup and air IDs
let std_sum_users = get_hint_ids_by_name(sctx.get_global_bin(), "std_sum_users")[0];

let num_users = get_global_hint_field_as::<usize, F>(sctx.clone(), std_sum_users, "num_users");
let num_users = get_global_hint_field_constant_as::<usize, F>(sctx.clone(), std_sum_users, "num_users");
let airgroup_ids = get_hint_field_gc_constant_a::<F>(sctx.clone(), std_sum_users, "airgroup_ids", false);
let air_ids = get_hint_field_gc_constant_a::<F>(sctx.clone(), std_sum_users, "air_ids", false);

Expand Down
5 changes: 4 additions & 1 deletion pil2-components/test/std/connection/connection.pil
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require "goldilocks.pil"

require "std_constants.pil"
require "gl_groups_small.pil";
// require "gl_groups_big.pil";

Expand All @@ -9,6 +10,8 @@ const int TEST_OPID = 44;
// TODO: Finish Connection2 and Connection3, compute examples of permutations
// TODO: Add examples combining both approaches

set_bus_type(PIOP_BUS_PROD);

airtemplate Connection1(const int N = 2**3) {
if (N != 2**3) error(`Unsupported N = ${N}`);

Expand Down
5 changes: 3 additions & 2 deletions pil2-components/test/std/permutation/permutation.pil
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require "std_permutation.pil";
require "std_constants.pil";

airtemplate Permutation1(const int N = 2**6) {

Expand All @@ -14,7 +15,7 @@ airtemplate Permutation1(const int N = 2**6) {
permutation_assumes(3, [a3, b3], sel1);
permutation_proves(3, [c2, d2], sel2);

permutation_assumes(4, [a4, b4], sel3);
permutation_assumes(4, [a4, b4], sel3, bus_type: PIOP_BUS_PROD);
};

airtemplate Permutation2(const int N = 2**9) {
Expand All @@ -24,7 +25,7 @@ airtemplate Permutation2(const int N = 2**9) {

permutation_proves(2, [c1, d1]);

permutation_proves(4, [c2, d2], sel);
permutation_proves(4, [c2, d2], sel, bus_type: PIOP_BUS_PROD);
};

airgroup Permutation {
Expand Down

0 comments on commit f9621eb

Please sign in to comment.