diff --git a/pil2-components/lib/std/pil/std_connection.pil b/pil2-components/lib/std/pil/std_connection.pil index 9d735a56..bfa8dc5b 100644 --- a/pil2-components/lib/std/pil/std_connection.pil +++ b/pil2-components/lib/std/pil/std_connection.pil @@ -74,7 +74,7 @@ const int DEFAULT_CONNECTION_N = 0; * col witness a,b,c; * connection_init(opid, [a, b, c]); */ -function connection_init(const int opid, const expr cols[], int default_frame_size = DEFAULT_CONNECTION_N, const int bus_type = 0) { +function connection_init(const int opid, const expr cols[], int default_frame_size = DEFAULT_CONNECTION_N, const int bus_type = PIOP_BUS_DEFAULT) { if (default_frame_size == DEFAULT_CONNECTION_N) default_frame_size = N; if (default_frame_size < 1) { @@ -484,20 +484,6 @@ private function checkClosed() { } } -/** - * TODO - * - * @param {int} opid - The (unique) identifier of the connection - * @param {expr[]} cols - Array of columns to be connected - * @param {expr[]} conn - Fixed columns indicating the connection - * @example - * col witness a,b,c; - * col fixed S1,S2,S3; - * // Compute S1, S2, S3... - * connection(opid, [a, b, c], [S1, S2, S3]); - * connection(opid, [a, b, c], [S1, S2, S3], N/2); - */ - /** * Connects the columns `cols` with the fixed columns `CONN`. * @@ -510,7 +496,7 @@ private function checkClosed() { * col fixed S1,S2,S3; * connection(opid, [a, b, c], [S1, S2, S3]); */ -function connection(const int opid, const expr cols[], const expr CONN[], const int bus_type = 0) { +function connection(const int opid, const expr cols[], const expr CONN[], const int bus_type = PIOP_BUS_DEFAULT) { const int len = length(cols); if (len == 0) { error(`Connection #${opid} cannot be empty`); diff --git a/pil2-components/lib/std/rs/src/common.rs b/pil2-components/lib/std/rs/src/common.rs index effeadaf..6b66a230 100644 --- a/pil2-components/lib/std/rs/src/common.rs +++ b/pil2-components/lib/std/rs/src/common.rs @@ -4,7 +4,9 @@ use p3_field::PrimeField; use num_traits::ToPrimitive; use proofman::{get_hint_field_constant_gc, WitnessManager}; use proofman_common::{AirInstance, ProofCtx, SetupCtx, StdMode}; -use proofman_hints::{get_hint_field_constant, HintFieldOptions, HintFieldOutput, HintFieldValue}; +use proofman_hints::{ + get_hint_field_constant, get_hint_field_constant_a, HintFieldOptions, HintFieldOutput, HintFieldValue, +}; pub trait AirComponent { const MY_NAME: &'static str; @@ -62,6 +64,41 @@ pub fn get_hint_field_constant_as_field( } } +pub fn get_hint_field_constant_a_as_string( + sctx: &SetupCtx, + airgroup_id: usize, + air_id: usize, + hint_id: usize, + field_name: &str, + hint_field_options: HintFieldOptions, +) -> Vec { + let hint_field = get_hint_field_constant_a::(sctx, airgroup_id, air_id, hint_id, field_name, hint_field_options); + + let mut return_values: Vec = Vec::new(); + for (i, hint_field) in hint_field.iter().enumerate() { + match hint_field { + HintFieldValue::String(value) => return_values.push(value.clone()), + _ => panic!("Hint '{}' for field '{}' at position '{}' must be a string", hint_id, field_name, i), + } + } + + return_values +} + +pub fn get_hint_field_constant_as_string( + sctx: &SetupCtx, + airgroup_id: usize, + air_id: usize, + hint_id: usize, + field_name: &str, + hint_field_options: HintFieldOptions, +) -> String { + match get_hint_field_constant::(sctx, airgroup_id, air_id, hint_id, field_name, hint_field_options) { + HintFieldValue::String(value) => value, + _ => panic!("Hint '{}' for field '{}' must be a string", hint_id, field_name), + } +} + // Helper to extract a single field element as usize pub fn extract_field_element_as_usize(field: &HintFieldValue, name: &str) -> usize { let HintFieldValue::Field(field_value) = field else { diff --git a/pil2-components/lib/std/rs/src/debug.rs b/pil2-components/lib/std/rs/src/debug.rs index 171eb430..7f73443f 100644 --- a/pil2-components/lib/std/rs/src/debug.rs +++ b/pil2-components/lib/std/rs/src/debug.rs @@ -7,13 +7,14 @@ use std::{ }; use p3_field::PrimeField; +use proofman_common::ProofCtx; use proofman_hints::{format_vec, HintFieldOutput}; pub type DebugData = Mutex>, BusValue>>>; // opid -> val -> BusValue pub struct BusValue { shared_data: SharedData, // Data shared across all airgroups, airs, and instances - grouped_data: AirGroupMap, // Data grouped by: airgroup_id -> air_id -> instance_id -> MetaData + grouped_data: AirGroupMap, // Data grouped by: airgroup_id -> air_id -> instance_id -> InstanceData } struct SharedData { @@ -22,11 +23,18 @@ struct SharedData { num_assumes: F, } -type AirGroupMap = HashMap; -type AirIdMap = HashMap; -type InstanceMap = HashMap; +type AirGroupMap = HashMap; +type AirMap = HashMap; -struct MetaData { +struct AirData { + name_piop: String, + name_expr: Vec, + instances: InstanceMap, +} + +type InstanceMap = HashMap; + +struct InstanceData { row_proves: Vec, row_assumes: Vec, } @@ -34,6 +42,8 @@ struct MetaData { #[allow(clippy::too_many_arguments)] pub fn update_debug_data( debug_data: &DebugData, + name_piop: &String, + name_expr: &Vec, opid: F, val: Vec>, airgroup_id: usize, @@ -58,9 +68,14 @@ pub fn update_debug_data( .entry(airgroup_id) .or_default() .entry(air_id) - .or_default() + .or_insert_with(|| AirData { + name_piop: name_piop.clone(), + name_expr: name_expr.clone(), + instances: InstanceMap::new(), + }) + .instances .entry(instance_id) - .or_insert_with(|| MetaData { row_proves: Vec::new(), row_assumes: Vec::new() }); + .or_insert_with(|| InstanceData { row_proves: Vec::new(), row_assumes: Vec::new() }); // If the value is global but it was already processed, skip it if is_global { @@ -81,6 +96,7 @@ pub fn update_debug_data( } pub fn print_debug_info( + pctx: &ProofCtx, name: &str, max_values_to_print: usize, print_to_file: bool, @@ -151,7 +167,7 @@ pub fn print_debug_info( } let shared_data = &data.shared_data; let grouped_data = &mut data.grouped_data; - print_diffs(val, max_values_to_print, shared_data, grouped_data, false, &mut output); + print_diffs(pctx, val, max_values_to_print, shared_data, grouped_data, false, &mut output); } if len_overassumed > 0 { @@ -176,7 +192,7 @@ pub fn print_debug_info( let shared_data = &data.shared_data; let grouped_data = &mut data.grouped_data; - print_diffs(val, max_values_to_print, shared_data, grouped_data, true, &mut output); + print_diffs(pctx, val, max_values_to_print, shared_data, grouped_data, true, &mut output); } if len_overproven > 0 { @@ -185,6 +201,7 @@ pub fn print_debug_info( } fn print_diffs( + pctx: &ProofCtx, val: &[HintFieldOutput], max_values_to_print: usize, shared_data: &SharedData, @@ -211,8 +228,8 @@ pub fn print_debug_info( // Collect and organize rows let mut organized_rows = Vec::new(); for (airgroup_id, air_id_map) in grouped_data.iter_mut() { - for (air_id, instance_map) in air_id_map.iter_mut() { - for (instance_id, meta_data) in instance_map.iter_mut() { + for (air_id, air_data) in air_id_map.iter_mut() { + for (instance_id, meta_data) in air_data.instances.iter_mut() { let rows = { let rows = if proves { &meta_data.row_proves } else { &meta_data.row_assumes }; if rows.is_empty() { @@ -230,6 +247,11 @@ pub fn print_debug_info( // Print grouped rows for (airgroup_id, air_id, instance_id, mut rows) in organized_rows { + let airgroup_name = pctx.global_info.get_air_group_name(airgroup_id); + let air_name = pctx.global_info.get_air_name(airgroup_id, air_id); + let piop_name = &grouped_data.get(&airgroup_id).unwrap().get(&air_id).unwrap().name_piop; + let expr_name = &grouped_data.get(&airgroup_id).unwrap().get(&air_id).unwrap().name_expr; + rows.sort(); let rows_display = rows.iter().map(|x| x.to_string()).take(max_values_to_print).collect::>().join(","); @@ -237,15 +259,31 @@ pub fn print_debug_info( let truncated = rows.len() > max_values_to_print; writeln!( output, - "\t Airgroup: {:<3} | Air: {:<3} | Instance: {:<3} | Num: {:<9} | Rows: [{}{}]", - airgroup_id, - air_id, - instance_id, - rows.len(), - rows_display, - if truncated { ",..." } else { "" }, - ) - .expect("Write error"); + "\t - Airgroup: {} (id: {})", + airgroup_name, airgroup_id + ).expect("Write error"); + writeln!( + output, + "\t Air: {} (id: {})", + air_name, air_id + ).expect("Write error"); + + writeln!( + output, + "\t PIOP: {}", + piop_name + ).expect("Write error"); + writeln!( + output, + "\t Expression: {:?}", + expr_name + ).expect("Write error"); + + writeln!( + output, + "\t Instance ID: {} | Num: {} | Rows: [{}{}]", + instance_id, rows.len(), rows_display, if truncated { ",..." } else { "" } + ).expect("Write error"); } writeln!(output, "\t --------------------------------------------------").expect("Write error"); diff --git a/pil2-components/lib/std/rs/src/std_prod.rs b/pil2-components/lib/std/rs/src/std_prod.rs index ccaf548d..2af99ad7 100644 --- a/pil2-components/lib/std/rs/src/std_prod.rs +++ b/pil2-components/lib/std/rs/src/std_prod.rs @@ -8,16 +8,18 @@ use p3_field::PrimeField; use proofman::{get_hint_field_gc_constant_a, WitnessComponent, WitnessManager}; use proofman_common::{AirInstance, ExecutionCtx, ModeName, ProofCtx, SetupCtx, StdMode}; use proofman_hints::{ - get_hint_field, get_hint_field_a, get_hint_field_constant, get_hint_field_constant_a, acc_mul_hint_fields, - update_airgroupvalue, get_hint_ids_by_name, HintFieldOptions, HintFieldValue, HintFieldValuesVec, + get_hint_field, get_hint_field_a, acc_mul_hint_fields, update_airgroupvalue, get_hint_ids_by_name, + HintFieldOptions, HintFieldValue, HintFieldValuesVec, }; use crate::{ - extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_as_field, - get_row_field_value, print_debug_info, update_debug_data, AirComponent, DebugData, + extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_a_as_string, + get_hint_field_constant_as_field, get_hint_field_constant_as_string, get_row_field_value, print_debug_info, + update_debug_data, AirComponent, DebugData, }; pub struct StdProd { + pctx: Arc>, mode: StdMode, stage_wc: Option>, debug_data: Option>, @@ -40,6 +42,7 @@ impl AirComponent for StdProd { // Initialize std_prod with the extracted data let mode = mode.expect("Mode must be provided"); let std_prod = Arc::new(Self { + pctx: wcm.get_pctx(), mode: mode.clone(), stage_wc: match std_prod_users_id.is_empty() { true => None, @@ -75,7 +78,7 @@ impl AirComponent for StdProd { // Process each debug hint for &hint in debug_data_hints.iter() { // Extract hint fields - let _name_piop = get_hint_field_constant::( + let name_piop = get_hint_field_constant_as_string::( sctx, airgroup_id, air_id, @@ -84,7 +87,7 @@ impl AirComponent for StdProd { HintFieldOptions::default(), ); - let _name_expr = get_hint_field_constant_a::( + let name_expr = get_hint_field_constant_a_as_string::( sctx, airgroup_id, air_id, @@ -111,24 +114,10 @@ impl AirComponent for StdProd { HintFieldOptions::default(), ); - let proves = get_hint_field_constant_as_field::( - sctx, - airgroup_id, - air_id, - hint as usize, - "proves", - HintFieldOptions::default(), - ); - let proves = if proves.is_zero() { - false - } else if proves.is_one() { - true - } else { - log::error!("Proves hint must be either 0 or 1"); - panic!(); - }; - - let selector: HintFieldValue = + let proves = + get_hint_field::(sctx, pctx, air_instance, hint as usize, "proves", HintFieldOptions::default()); + + let sel: HintFieldValue = get_hint_field::(sctx, pctx, air_instance, hint as usize, "selector", HintFieldOptions::default()); let expressions = get_hint_field_a::( @@ -161,12 +150,14 @@ impl AirComponent for StdProd { // If both the expresion and the mul are of degree zero, then simply update the bus once if deg_expr.is_zero() && deg_sel.is_zero() { update_bus( + &name_piop, + &name_expr, airgroup_id, air_id, instance_id, opid, - proves, - &selector, + &proves, + &sel, &expressions, 0, debug_data, @@ -177,12 +168,14 @@ impl AirComponent for StdProd { else { for j in 0..num_rows { update_bus( + &name_piop, + &name_expr, airgroup_id, air_id, instance_id, opid, - proves, - &selector, + &proves, + &sel, &expressions, j, debug_data, @@ -193,24 +186,40 @@ impl AirComponent for StdProd { #[allow(clippy::too_many_arguments)] fn update_bus( + name_piop: &String, + name_expr: &Vec, airgroup_id: usize, air_id: usize, instance_id: usize, opid: F, - proves: bool, - selector: &HintFieldValue, + proves: &HintFieldValue, + sel: &HintFieldValue, expressions: &HintFieldValuesVec, row: usize, debug_data: &DebugData, is_global: bool, ) { - let selector = get_row_field_value(selector, row, "sel"); - if selector.is_zero() { + let mut sel = get_row_field_value(sel, row, "sel"); + if sel.is_zero() { return; } + let proves = match get_row_field_value(proves, row, "proves") { + p if p.is_zero() || p == F::neg_one() => { + // If it's an "assume", negate its value + if p == F::neg_one() { + sel = -sel; + } + false + } + p if p.is_one() => true, + _ => panic!("Proves hint must be either 0, 1, or -1"), + }; + update_debug_data( debug_data, + name_piop, + name_expr, opid, expressions.get(row), airgroup_id, @@ -218,7 +227,7 @@ impl AirComponent for StdProd { instance_id, row, proves, - F::one(), + sel, is_global, ); } @@ -318,17 +327,20 @@ impl WitnessComponent for StdProd { } } } + + // TODO: Process each direct update to the product bus } } fn end_proof(&self) { // Print debug info if in debug mode if self.mode.name == ModeName::Debug { + let pctx = &self.pctx; let name = Self::MY_NAME; let max_values_to_print = self.mode.n_vals; let print_to_file = self.mode.print_to_file; let debug_data = self.debug_data.as_ref().expect("Debug data missing"); - print_debug_info(name, max_values_to_print, print_to_file, debug_data); + print_debug_info(pctx, name, max_values_to_print, print_to_file, debug_data); } } } diff --git a/pil2-components/lib/std/rs/src/std_sum.rs b/pil2-components/lib/std/rs/src/std_sum.rs index 77b4ae8e..c3c3f136 100644 --- a/pil2-components/lib/std/rs/src/std_sum.rs +++ b/pil2-components/lib/std/rs/src/std_sum.rs @@ -9,17 +9,18 @@ use p3_field::PrimeField; use proofman::{get_hint_field_gc_constant_a, WitnessComponent, WitnessManager}; use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx, StdMode, ModeName}; use proofman_hints::{ - get_hint_field, get_hint_field_a, get_hint_field_constant, get_hint_field_constant_a, acc_mul_hint_fields, - update_airgroupvalue, get_hint_ids_by_name, mul_hint_fields, HintFieldOptions, HintFieldOutput, HintFieldValue, - HintFieldValuesVec, + get_hint_field, get_hint_field_a, acc_mul_hint_fields, update_airgroupvalue, get_hint_ids_by_name, mul_hint_fields, + HintFieldOptions, HintFieldOutput, HintFieldValue, HintFieldValuesVec, }; use crate::{ - extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_as_field, - get_row_field_value, print_debug_info, update_debug_data, AirComponent, DebugData, + extract_field_element_as_usize, get_global_hint_field_constant_as, get_hint_field_constant_a_as_string, + get_hint_field_constant_as_field, get_hint_field_constant_as_string, get_row_field_value, print_debug_info, + update_debug_data, AirComponent, DebugData, }; pub struct StdSum { + pctx: Arc>, mode: StdMode, stage_wc: Option>, debug_data: Option>, @@ -42,6 +43,7 @@ impl AirComponent for StdSum { // Initialize std_sum with the extracted data let mode = mode.expect("Mode must be provided"); let std_sum = Arc::new(Self { + pctx: wcm.get_pctx(), mode: mode.clone(), stage_wc: match std_sum_users_id.is_empty() { true => None, @@ -77,7 +79,7 @@ impl AirComponent for StdSum { // Process each debug hint for &hint in debug_data_hints.iter() { // Extract hint fields - let _name_piop = get_hint_field_constant::( + let name_piop = get_hint_field_constant_as_string::( sctx, airgroup_id, air_id, @@ -86,7 +88,7 @@ impl AirComponent for StdSum { HintFieldOptions::default(), ); - let _name_expr = get_hint_field_constant_a::( + let name_expr = get_hint_field_constant_a_as_string::( sctx, airgroup_id, air_id, @@ -157,6 +159,8 @@ impl AirComponent for StdSum { }; update_bus( + &name_piop, + &name_expr, airgroup_id, air_id, instance_id, @@ -189,6 +193,8 @@ impl AirComponent for StdSum { }; update_bus( + &name_piop, + &name_expr, airgroup_id, air_id, instance_id, @@ -206,6 +212,8 @@ impl AirComponent for StdSum { #[allow(clippy::too_many_arguments)] fn update_bus( + name_piop: &String, + name_expr: &Vec, airgroup_id: usize, air_id: usize, instance_id: usize, @@ -236,6 +244,8 @@ impl AirComponent for StdSum { update_debug_data( debug_data, + name_piop, + name_expr, opid, expressions.get(row), airgroup_id, @@ -360,17 +370,20 @@ impl WitnessComponent for StdSum { } } } + + // TODO: Process each direct update to the product bus } } fn end_proof(&self) { // Print debug info if in debug mode if self.mode.name == ModeName::Debug { + let pctx = &self.pctx; let name = Self::MY_NAME; let max_values_to_print = self.mode.n_vals; let print_to_file = self.mode.print_to_file; let debug_data = self.debug_data.as_ref().expect("Debug data missing"); - print_debug_info(name, max_values_to_print, print_to_file, debug_data); + print_debug_info(pctx, name, max_values_to_print, print_to_file, debug_data); } } } diff --git a/pil2-components/test/simple/rs/src/simple_left.rs b/pil2-components/test/simple/rs/src/simple_left.rs index ba705325..720e265e 100644 --- a/pil2-components/test/simple/rs/src/simple_left.rs +++ b/pil2-components/test/simple/rs/src/simple_left.rs @@ -4,7 +4,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx}; use p3_field::PrimeField; -use rand::{distributions::Standard, prelude::Distribution, seq::SliceRandom}; +use rand::{distributions::Standard, prelude::Distribution, seq::SliceRandom, Rng}; use crate::{SimpleLeftTrace, SIMPLE_AIRGROUP_ID, SIMPLE_LEFT_AIR_IDS}; @@ -77,8 +77,8 @@ where // Assumes for i in 0..num_rows { - trace[i].a = F::from_canonical_usize(i); - trace[i].b = F::from_canonical_usize(i); + trace[i].a = F::from_canonical_u64(rng.gen_range(0..=(1 << 63) - 1)); + trace[i].b = F::from_canonical_u64(rng.gen_range(0..=(1 << 63) - 1)); trace[i].e = F::from_canonical_u8(200); trace[i].f = F::from_canonical_u8(201); diff --git a/pil2-components/test/std/direct_update/direct_update.pil b/pil2-components/test/std/direct_update/direct_update.pil index b699364a..e6aff72f 100644 --- a/pil2-components/test/std/direct_update/direct_update.pil +++ b/pil2-components/test/std/direct_update/direct_update.pil @@ -36,8 +36,6 @@ airtemplate DirectUpdateProdGlobal(const int N = 2**4) { airgroup DirectUpdateProd { DirectUpdateProdLocal(); - DirectUpdateProdGlobal(); - public c_public[2]; proofval d_proofval_0; proofval d_proofval_1; @@ -45,6 +43,17 @@ airgroup DirectUpdateProd { proofval perform_global_update_1; direct_global_update_proves(OP_BUS_ID2, [OPID2, ...c_public, d_proofval_0, d_proofval_1], sel: perform_global_update_0, bus_type: PIOP_BUS_PROD); direct_global_update_proves(OP_BUS_ID2, [OPID2, ...c_public, d_proofval_0, d_proofval_1], sel: perform_global_update_1, bus_type: PIOP_BUS_PROD); + + DirectUpdateProdGlobal(); + + // TODO: Uncommented when compiler bug is fixed + // public c_public[2]; + // proofval d_proofval_0; + // proofval d_proofval_1; + // proofval perform_global_update_0; + // proofval perform_global_update_1; + // direct_global_update_proves(OP_BUS_ID2, [OPID2, ...c_public, d_proofval_0, d_proofval_1], sel: perform_global_update_0, bus_type: PIOP_BUS_PROD); + // direct_global_update_proves(OP_BUS_ID2, [OPID2, ...c_public, d_proofval_0, d_proofval_1], sel: perform_global_update_1, bus_type: PIOP_BUS_PROD); } const int OP_BUS_ID3 = 300; @@ -68,7 +77,7 @@ airtemplate DirectUpdateSumLocal(const int N = 2**5) { } const int OP_BUS_ID4 = 400; -const int OPID4 = 555; +const int OPID4 = 666; airtemplate DirectUpdateSumGlobal(const int N = 2**5) { col witness c[2],d[2];