Skip to content

Commit

Permalink
STD debug mode improvements and fixes (#126)
Browse files Browse the repository at this point in the history
* wip

* First improved debug version done

* Fixe issue overcounting when expressions are of degree 0

* Cargo clippy
  • Loading branch information
hecmas authored Dec 3, 2024
1 parent 414130f commit bfd61ba
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 172 deletions.
15 changes: 7 additions & 8 deletions pil2-components/lib/std/pil/std_prod.pil
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private function init_proof_containers_prod(int name, int opid) {
}
}

private function init_containers_prod(int name, int opid) {
private function init_air_containers_prod(int name, int opid) {
container air.std.gprod {
int gprod_assumes_count = 0;
expr gprod_assumes_sel[100];
Expand Down Expand Up @@ -104,17 +104,16 @@ private function update_piop_prod(int name, int proves, int opid, expr sel, expr
init_proof_containers_prod(name, opid);

if (direct_type == PIOP_DIRECT_TYPE_AIR || direct_type == PIOP_DIRECT_TYPE_DEFAULT) {
init_containers_prod(name, opid);
}
init_air_containers_prod(name, opid);

if (direct_type == PIOP_DIRECT_TYPE_DEFAULT) {
// Create debug hints for the witness computation
const int ncols = length(cols);
string name_cols[ncols];
for (int i = 0; i < ncols; i++) {
string name_cols[cols_count];
expr sum_expr = 0;
for (int i = 0; i < cols_count; i++) {
name_cols[i] = string(cols[i]);
sum_expr += cols[i];
}
@gprod_member_data{name_piop: get_piop_name(name), names: name_cols, opid: opid, proves: proves, selector: sel, references: cols};
@gprod_member_data{name_piop: get_piop_name(name), name_expr: name_expr, opid: opid, proves: proves, selector: sel, expressions: cols, deg_expr: degree(sum_expr), deg_sel: degree(sel)};
}

initial_checks_prod(proves, opid, cols, direct_type);
Expand Down
17 changes: 8 additions & 9 deletions pil2-components/lib/std/pil/std_sum.pil
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private function init_proof_containers_sum(int name, int opid[]) {
}
}

private function init_containers_sum(int name, int opid[]) {
private function init_air_containers_sum(int name, int opid[]) {
container air.std.gsum {
int gsum_nargs = 0;
expr gsum_s[100];
Expand Down Expand Up @@ -112,18 +112,17 @@ private function update_piop_sum(int name, int proves, int opid[], expr sumid, e
init_proof_containers_sum(name, opid);

if (direct_type == PIOP_DIRECT_TYPE_AIR || direct_type == PIOP_DIRECT_TYPE_DEFAULT) {
init_containers_sum(name, opid);
}
init_air_containers_sum(name, opid);

if (direct_type == PIOP_DIRECT_TYPE_DEFAULT) {
// Create debug hints for the witness computation
const int ncols = length(cols);
string name_cols[ncols];
for (int i = 0; i < ncols; i++) {
name_cols[i] = string(cols[i]);
string name_expr[cols_count];
expr sum_expr = 0;
for (int i = 0; i < cols_count; i++) {
name_expr[i] = string(cols[i]);
sum_expr += cols[i];
}
// proves = 2 marks that the user is responsible for the use of proves and assumes
@gsum_member_data{name_piop: get_piop_name(name), names: name_cols, sumid: sumid, proves: proves == 2 ? sel : proves, selector: sel, references: cols};
@gsum_member_data{name_piop: get_piop_name(name), name_expr: name_expr, opid: sumid, proves: proves == 2 ? sel : proves, selector: sel, expressions: cols, deg_expr: degree(sum_expr), deg_sel: degree(sel)};
}

initial_checks_sum(proves, opid, cols, direct_type);
Expand Down
168 changes: 126 additions & 42 deletions pil2-components/lib/std/rs/src/debug.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,74 @@
use std::{collections::HashMap, sync::Mutex};

use num_traits::ToPrimitive;
use p3_field::PrimeField;
use proofman_hints::{format_vec, HintFieldOutput};

pub type DebugData<F> = Mutex<HashMap<F, HashMap<Vec<HintFieldOutput<F>>, BusValue<F>>>>; // opid -> val -> BusValue

pub struct BusValue<F: PrimeField> {
pub num_proves: F,
pub num_assumes: F,
// meta data
pub row_proves: Vec<usize>,
pub row_assumes: Vec<usize>,
pub struct BusValue<F> {
shared_data: SharedData<F>, // Data shared across all airgroups, airs, and instances
grouped_data: AirGroupMap, // Data grouped by: airgroup_id -> air_id -> instance_id -> MetaData
}

struct SharedData<F> {
num_proves: F,
num_assumes: F,
}

type AirGroupMap = HashMap<usize, AirIdMap>;
type AirIdMap = HashMap<usize, InstanceMap>;
type InstanceMap = HashMap<usize, MetaData>;

struct MetaData {
row_proves: Vec<usize>,
row_assumes: Vec<usize>,
}

#[allow(clippy::too_many_arguments)]
pub fn update_debug_data<F: PrimeField>(
debug_data: &DebugData<F>,
opid: F,
val: Vec<HintFieldOutput<F>>,
airgroup_id: usize,
air_id: usize,
instance_id: usize,
row: usize,
proves: bool,
times: F,
) {
let mut bus = debug_data.lock().expect("Bus values missing");

let bus_opid = bus.entry(opid).or_default();

let bus_val = bus_opid.entry(val).or_insert_with(|| BusValue {
shared_data: SharedData { num_proves: F::zero(), num_assumes: F::zero() },
grouped_data: AirGroupMap::new(),
});

let grouped_data = bus_val
.grouped_data
.entry(airgroup_id)
.or_default()
.entry(air_id)
.or_default()
.entry(instance_id)
.or_insert_with(|| MetaData { row_proves: Vec::new(), row_assumes: Vec::new() });

if proves {
bus_val.shared_data.num_proves += times;
grouped_data.row_proves.push(row);
} else {
assert!(times.is_one(), "The selector value is invalid: expected 1, but received {:?}.", times);
bus_val.shared_data.num_assumes += times;
grouped_data.row_assumes.push(row);
}
}

pub fn print_debug_info<F: PrimeField>(name: &str, max_values_to_print: usize, debug_data: &DebugData<F>) {
let mut there_are_errors = false;
let mut bus_vals = debug_data.lock().expect("Bus values missing");
for (opid, bus) in bus_vals.iter_mut() {
if bus.iter().any(|(_, v)| v.num_proves != v.num_assumes) {
if bus.iter().any(|(_, v)| v.shared_data.num_proves != v.shared_data.num_assumes) {
if !there_are_errors {
there_are_errors = true;
log::error!("{}: Some bus values do not match.", name);
Expand All @@ -30,7 +80,7 @@ pub fn print_debug_info<F: PrimeField>(name: &str, max_values_to_print: usize, d

// TODO: Sort unmatching values by the row
let mut overassumed_values: Vec<(&Vec<HintFieldOutput<F>>, &mut BusValue<F>)> =
bus.iter_mut().filter(|(_, v)| v.num_proves < v.num_assumes).collect();
bus.iter_mut().filter(|(_, v)| v.shared_data.num_proves < v.shared_data.num_assumes).collect();
let len_overassumed = overassumed_values.len();

if len_overassumed > 0 {
Expand All @@ -42,7 +92,9 @@ pub fn print_debug_info<F: PrimeField>(name: &str, max_values_to_print: usize, d
println!("\t ...");
break;
}
print_diffs(val, max_values_to_print, data.num_assumes, data.num_proves, &mut data.row_assumes, false);
let shared_data = &data.shared_data;
let grouped_data = &mut data.grouped_data;
print_diffs(val, max_values_to_print, shared_data, grouped_data, false);
}

if len_overassumed > 0 {
Expand All @@ -51,7 +103,7 @@ pub fn print_debug_info<F: PrimeField>(name: &str, max_values_to_print: usize, d

// TODO: Sort unmatching values by the row
let mut overproven_values: Vec<(&Vec<HintFieldOutput<F>>, &mut BusValue<F>)> =
bus.iter_mut().filter(|(_, v)| v.num_proves > v.num_assumes).collect();
bus.iter_mut().filter(|(_, v)| v.shared_data.num_proves > v.shared_data.num_assumes).collect();
let len_overproven = overproven_values.len();

if len_overproven > 0 {
Expand All @@ -63,7 +115,10 @@ pub fn print_debug_info<F: PrimeField>(name: &str, max_values_to_print: usize, d
println!("\t ...");
break;
}
print_diffs(val, max_values_to_print, data.num_proves, data.num_assumes, &mut data.row_proves, true);

let shared_data = &data.shared_data;
let grouped_data = &mut data.grouped_data;
print_diffs(val, max_values_to_print, shared_data, grouped_data, true);
}

if len_overproven > 0 {
Expand All @@ -74,39 +129,68 @@ pub fn print_debug_info<F: PrimeField>(name: &str, max_values_to_print: usize, d
fn print_diffs<F: PrimeField>(
val: &[HintFieldOutput<F>],
max_values_to_print: usize,
num_vals_left: F,
num_vals_right: F,
rows: &mut [usize],
reverse_print: bool,
shared_data: &SharedData<F>,
grouped_data: &mut AirGroupMap,
proves: bool,
) {
let diff = num_vals_left - num_vals_right;
let diff = diff.as_canonical_biguint().to_usize().expect("Cannot convert to usize");

rows.sort();
let rows = rows
.iter()
.map(|x| x.to_string())
.take(std::cmp::min(max_values_to_print, diff))
.collect::<Vec<_>>()
.join(",");

let name_str = match rows.len() {
1 => format!("at row {}.", rows),
len if max_values_to_print < len => format!("at rows {},...", rows),
_ => format!("at rows {}.", rows),
};
let diff_str = if diff == 1 { "time" } else { "times" };

let (num_assumes, num_proves) =
if reverse_print { (num_vals_right, num_vals_left) } else { (num_vals_left, num_vals_right) };
let num_assumes = shared_data.num_assumes;
let num_proves = shared_data.num_proves;

let num = if proves { num_proves } else { num_assumes };
let num_str = if num.is_one() { "time" } else { "times" };

println!("\t ==================================================");
println!(
"\t • Value:\n\t {}\n\t Appears {} {} {}\n\t Num Assumes: {}.\n\t Num Proves: {}.",
"\t • Value:\n\t {}\n\t Appears {} {} across the following:",
format_vec(val),
diff,
diff_str,
name_str,
num_assumes,
num_proves
num,
num_str,
);

// Collect and organize rows
let mut organized_rows = Vec::new();
for (airgroup_id, air_id_map) in grouped_data.iter_mut() {
for (air_id, instance_map) in air_id_map.iter_mut() {
for (instance_id, meta_data) in instance_map.iter_mut() {
let rows = {
let rows = if proves { &meta_data.row_proves } else { &meta_data.row_assumes };
if rows.is_empty() {
continue;
}
rows.clone()
};
organized_rows.push((*airgroup_id, *air_id, *instance_id, rows));
}
}
}

// Sort rows by airgroup_id, air_id, and instance_id
organized_rows.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)).then(a.2.cmp(&b.2)));

// Print grouped rows
for (airgroup_id, air_id, instance_id, mut rows) in organized_rows {
rows.sort();
let rows_display =
rows.iter().map(|x| x.to_string()).take(max_values_to_print).collect::<Vec<_>>().join(",");

let truncated = rows.len() > max_values_to_print;
println!(
"\t Airgroup: {:<3} | Air: {:<3} | Instance: {:<3} | Num: {:<9} | Rows: [{}{}]",
airgroup_id,
air_id,
instance_id,
rows.len(),
rows_display,
if truncated { ",..." } else { "" },
);
}

println!("\t --------------------------------------------------");
let diff = if proves { num_proves - num_assumes } else { num_assumes - num_proves };
println!(
"\t Total Num Assumes: {}.\n\t Total Num Proves: {}.\n\t Total Unmatched: {}.",
num_assumes, num_proves, diff
);
println!("\t ==================================================\n");
}
}
Loading

0 comments on commit bfd61ba

Please sign in to comment.