Skip to content

Commit

Permalink
New features to debug mode
Browse files Browse the repository at this point in the history
  • Loading branch information
hecmas committed Dec 20, 2024
1 parent b413660 commit 6855cd0
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 85 deletions.
18 changes: 2 additions & 16 deletions pil2-components/lib/std/pil/std_connection.pil
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ const int DEFAULT_CONNECTION_N = 0;
* col witness a,b,c;
* connection_init(opid, [a, b, c]);
*/
function connection_init(const int opid, const expr cols[], int default_frame_size = DEFAULT_CONNECTION_N, const int bus_type = 0) {
function connection_init(const int opid, const expr cols[], int default_frame_size = DEFAULT_CONNECTION_N, const int bus_type = PIOP_BUS_DEFAULT) {
if (default_frame_size == DEFAULT_CONNECTION_N) default_frame_size = N;

if (default_frame_size < 1) {
Expand Down Expand Up @@ -484,20 +484,6 @@ private function checkClosed() {
}
}

/**
* TODO
*
* @param {int} opid - The (unique) identifier of the connection
* @param {expr[]} cols - Array of columns to be connected
* @param {expr[]} conn - Fixed columns indicating the connection
* @example
* col witness a,b,c;
* col fixed S1,S2,S3;
* // Compute S1, S2, S3...
* connection(opid, [a, b, c], [S1, S2, S3]);
* connection(opid, [a, b, c], [S1, S2, S3], N/2);
*/

/**
* Connects the columns `cols` with the fixed columns `CONN`.
*
Expand All @@ -510,7 +496,7 @@ private function checkClosed() {
* col fixed S1,S2,S3;
* connection(opid, [a, b, c], [S1, S2, S3]);
*/
function connection(const int opid, const expr cols[], const expr CONN[], const int bus_type = 0) {
function connection(const int opid, const expr cols[], const expr CONN[], const int bus_type = PIOP_BUS_DEFAULT) {
const int len = length(cols);
if (len == 0) {
error(`Connection #${opid} cannot be empty`);
Expand Down
39 changes: 38 additions & 1 deletion pil2-components/lib/std/rs/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use p3_field::PrimeField;
use num_traits::ToPrimitive;
use proofman::{get_hint_field_constant_gc, WitnessManager};
use proofman_common::{AirInstance, ProofCtx, SetupCtx, StdMode};
use proofman_hints::{get_hint_field_constant, HintFieldOptions, HintFieldOutput, HintFieldValue};
use proofman_hints::{
get_hint_field_constant, get_hint_field_constant_a, HintFieldOptions, HintFieldOutput, HintFieldValue,
};

pub trait AirComponent<F> {
const MY_NAME: &'static str;
Expand Down Expand Up @@ -62,6 +64,41 @@ pub fn get_hint_field_constant_as_field<F: PrimeField>(
}
}

pub fn get_hint_field_constant_a_as_string<F: PrimeField>(
sctx: &SetupCtx,
airgroup_id: usize,
air_id: usize,
hint_id: usize,
field_name: &str,
hint_field_options: HintFieldOptions,
) -> Vec<String> {
let hint_field = get_hint_field_constant_a::<F>(sctx, airgroup_id, air_id, hint_id, field_name, hint_field_options);

let mut return_values: Vec<String> = Vec::new();
for (i, hint_field) in hint_field.iter().enumerate() {
match hint_field {
HintFieldValue::String(value) => return_values.push(value.clone()),
_ => panic!("Hint '{}' for field '{}' at position '{}' must be a string", hint_id, field_name, i),
}
}

return_values
}

pub fn get_hint_field_constant_as_string<F: PrimeField>(
sctx: &SetupCtx,
airgroup_id: usize,
air_id: usize,
hint_id: usize,
field_name: &str,
hint_field_options: HintFieldOptions,
) -> String {
match get_hint_field_constant::<F>(sctx, airgroup_id, air_id, hint_id, field_name, hint_field_options) {
HintFieldValue::String(value) => value,
_ => panic!("Hint '{}' for field '{}' must be a string", hint_id, field_name),
}
}

// Helper to extract a single field element as usize
pub fn extract_field_element_as_usize<F: PrimeField>(field: &HintFieldValue<F>, name: &str) -> usize {
let HintFieldValue::Field(field_value) = field else {
Expand Down
78 changes: 58 additions & 20 deletions pil2-components/lib/std/rs/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ use std::{
};

use p3_field::PrimeField;
use proofman_common::ProofCtx;
use proofman_hints::{format_vec, HintFieldOutput};

pub type DebugData<F> = Mutex<HashMap<F, HashMap<Vec<HintFieldOutput<F>>, BusValue<F>>>>; // opid -> val -> BusValue

pub struct BusValue<F> {
shared_data: SharedData<F>, // Data shared across all airgroups, airs, and instances
grouped_data: AirGroupMap, // Data grouped by: airgroup_id -> air_id -> instance_id -> MetaData
grouped_data: AirGroupMap, // Data grouped by: airgroup_id -> air_id -> instance_id -> InstanceData
}

struct SharedData<F> {
Expand All @@ -22,18 +23,27 @@ struct SharedData<F> {
num_assumes: F,
}

type AirGroupMap = HashMap<usize, AirIdMap>;
type AirIdMap = HashMap<usize, InstanceMap>;
type InstanceMap = HashMap<usize, MetaData>;
type AirGroupMap = HashMap<usize, AirMap>;
type AirMap = HashMap<usize, AirData>;

struct MetaData {
struct AirData {
name_piop: String,
name_expr: Vec<String>,
instances: InstanceMap,
}

type InstanceMap = HashMap<usize, InstanceData>;

struct InstanceData {
row_proves: Vec<usize>,
row_assumes: Vec<usize>,
}

#[allow(clippy::too_many_arguments)]
pub fn update_debug_data<F: PrimeField>(
debug_data: &DebugData<F>,
name_piop: &String,
name_expr: &Vec<String>,
opid: F,
val: Vec<HintFieldOutput<F>>,
airgroup_id: usize,
Expand All @@ -58,9 +68,14 @@ pub fn update_debug_data<F: PrimeField>(
.entry(airgroup_id)
.or_default()
.entry(air_id)
.or_default()
.or_insert_with(|| AirData {
name_piop: name_piop.clone(),
name_expr: name_expr.clone(),
instances: InstanceMap::new(),
})
.instances
.entry(instance_id)
.or_insert_with(|| MetaData { row_proves: Vec::new(), row_assumes: Vec::new() });
.or_insert_with(|| InstanceData { row_proves: Vec::new(), row_assumes: Vec::new() });

// If the value is global but it was already processed, skip it
if is_global {
Expand All @@ -81,6 +96,7 @@ pub fn update_debug_data<F: PrimeField>(
}

pub fn print_debug_info<F: PrimeField>(
pctx: &ProofCtx<F>,
name: &str,
max_values_to_print: usize,
print_to_file: bool,
Expand Down Expand Up @@ -151,7 +167,7 @@ pub fn print_debug_info<F: PrimeField>(
}
let shared_data = &data.shared_data;
let grouped_data = &mut data.grouped_data;
print_diffs(val, max_values_to_print, shared_data, grouped_data, false, &mut output);
print_diffs(pctx, val, max_values_to_print, shared_data, grouped_data, false, &mut output);
}

if len_overassumed > 0 {
Expand All @@ -176,7 +192,7 @@ pub fn print_debug_info<F: PrimeField>(

let shared_data = &data.shared_data;
let grouped_data = &mut data.grouped_data;
print_diffs(val, max_values_to_print, shared_data, grouped_data, true, &mut output);
print_diffs(pctx, val, max_values_to_print, shared_data, grouped_data, true, &mut output);
}

if len_overproven > 0 {
Expand All @@ -185,6 +201,7 @@ pub fn print_debug_info<F: PrimeField>(
}

fn print_diffs<F: PrimeField>(
pctx: &ProofCtx<F>,
val: &[HintFieldOutput<F>],
max_values_to_print: usize,
shared_data: &SharedData<F>,
Expand All @@ -211,8 +228,8 @@ pub fn print_debug_info<F: PrimeField>(
// Collect and organize rows
let mut organized_rows = Vec::new();
for (airgroup_id, air_id_map) in grouped_data.iter_mut() {
for (air_id, instance_map) in air_id_map.iter_mut() {
for (instance_id, meta_data) in instance_map.iter_mut() {
for (air_id, air_data) in air_id_map.iter_mut() {
for (instance_id, meta_data) in air_data.instances.iter_mut() {
let rows = {
let rows = if proves { &meta_data.row_proves } else { &meta_data.row_assumes };
if rows.is_empty() {
Expand All @@ -230,22 +247,43 @@ pub fn print_debug_info<F: PrimeField>(

// Print grouped rows
for (airgroup_id, air_id, instance_id, mut rows) in organized_rows {
let airgroup_name = pctx.global_info.get_air_group_name(airgroup_id);
let air_name = pctx.global_info.get_air_name(airgroup_id, air_id);
let piop_name = &grouped_data.get(&airgroup_id).unwrap().get(&air_id).unwrap().name_piop;
let expr_name = &grouped_data.get(&airgroup_id).unwrap().get(&air_id).unwrap().name_expr;

rows.sort();
let rows_display =
rows.iter().map(|x| x.to_string()).take(max_values_to_print).collect::<Vec<_>>().join(",");

let truncated = rows.len() > max_values_to_print;
writeln!(
output,
"\t Airgroup: {:<3} | Air: {:<3} | Instance: {:<3} | Num: {:<9} | Rows: [{}{}]",
airgroup_id,
air_id,
instance_id,
rows.len(),
rows_display,
if truncated { ",..." } else { "" },
)
.expect("Write error");
"\t - Airgroup: {} (id: {})",
airgroup_name, airgroup_id
).expect("Write error");
writeln!(
output,
"\t Air: {} (id: {})",
air_name, air_id
).expect("Write error");

writeln!(
output,
"\t PIOP: {}",
piop_name
).expect("Write error");
writeln!(
output,
"\t Expression: {:?}",
expr_name
).expect("Write error");

writeln!(
output,
"\t Instance ID: {} | Num: {} | Rows: [{}{}]",
instance_id, rows.len(), rows_display, if truncated { ",..." } else { "" }
).expect("Write error");
}

writeln!(output, "\t --------------------------------------------------").expect("Write error");
Expand Down
Loading

0 comments on commit 6855cd0

Please sign in to comment.