Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugs in memory and binary extension #195

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion state-machines/arith/src/arith_full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ impl<F: PrimeField> ArithFullSM<F> {

if !binary_inputs.is_empty() {
timer_start_trace!(ARITH_BINARY);
info!("{}: ··· calling binary_sm", Self::MY_NAME);
self.binary_sm.prove(binary_inputs.as_slice(), false);
timer_stop_and_log_trace!(ARITH_BINARY);
}
Expand Down
2 changes: 1 addition & 1 deletion state-machines/binary/src/binary_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ impl<F: Field> BinaryBasicSM<F> {
assert!(operations.len() <= air.num_rows());

info!(
"{}: ··· Creating Binary basic instance [{} / {} rows filled {:.2}%]",
"{}: ··· Creating Binary instance [{} / {} rows filled {:.2}%]",
Self::MY_NAME,
operations.len(),
air.num_rows(),
Expand Down
28 changes: 20 additions & 8 deletions state-machines/binary/src/binary_basic_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::{
Arc, Mutex,
};

use log::info;
use p3_field::Field;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::AirInstance;
Expand Down Expand Up @@ -261,19 +260,32 @@ impl<F: Field> BinaryBasicTableSM<F> {

if is_myne {
// Create the prover buffer
let trace: BinaryTableTrace<'_, _> = BinaryTableTrace::new(self.num_rows);
let num_rows = self.num_rows;
let trace: BinaryTableTrace<'_, _> = BinaryTableTrace::new(num_rows);
let mut prover_buffer = trace.buffer.unwrap();

prover_buffer[0..self.num_rows]
let non_zero_multiplicities = prover_buffer[0..num_rows]
.par_iter_mut()
.enumerate()
.for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i]));

info!(
"{}: ··· Creating Binary basic table instance [{} rows filled 100%]",
.map(|(i, input)| {
*input = F::from_canonical_u64(multiplicity_[i]);
if multiplicity_[i] != 0 {
Some(1)
} else {
None
}
})
.filter_map(|x| x)
.sum::<usize>();

log::info!(
"{}: ··· Creating Binary Table instance [{} / {} rows used {:.2}%]",
Self::MY_NAME,
self.num_rows,
non_zero_multiplicities,
num_rows,
non_zero_multiplicities as f64 / num_rows as f64 * 100.0
);

let air_instance = AirInstance::new(
self.wcm.get_sctx(),
ZISK_AIRGROUP_ID,
Expand Down
4 changes: 2 additions & 2 deletions state-machines/binary/src/binary_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ impl<F: PrimeField> BinaryExtensionSM<F> {
if ((a_bytes[j] as u64) & SIGN_BYTE) != 0 {
out = (a_bytes[j] as u64) << 8 | SE_MASK_16;
} else {
out = a_bytes[j] as u64;
out = (a_bytes[j] as u64) << 8;
}
} else {
out = 0;
Expand Down Expand Up @@ -391,7 +391,7 @@ impl<F: PrimeField> BinaryExtensionSM<F> {
assert!(operations.len() <= air.num_rows());

info!(
"{}: ··· Creating Binary extension instance [{} / {} rows filled {:.2}%]",
"{}: ··· Creating Binary Extension instance [{} / {} rows filled {:.2}%]",
Self::MY_NAME,
operations.len(),
air.num_rows(),
Expand Down
28 changes: 19 additions & 9 deletions state-machines/binary/src/binary_extension_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::{
Arc, Mutex,
};

use log::info;
use p3_field::Field;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::AirInstance;
Expand Down Expand Up @@ -136,19 +135,30 @@ impl<F: Field> BinaryExtensionTableSM<F> {
dctx.distribute_multiplicity(&mut multiplicity_, owner);

if is_myne {
let trace: BinaryExtensionTableTrace<'_, _> =
BinaryExtensionTableTrace::new(self.num_rows);
let num_rows = self.num_rows;
let trace: BinaryExtensionTableTrace<'_, _> = BinaryExtensionTableTrace::new(num_rows);
let mut prover_buffer = trace.buffer.unwrap();

prover_buffer[0..self.num_rows]
let non_zero_multiplicities = prover_buffer[0..num_rows]
.par_iter_mut()
.enumerate()
.for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i]));

info!(
"{}: ··· Creating Binary extension table instance [{} rows filled 100%]",
.map(|(i, input)| {
*input = F::from_canonical_u64(multiplicity_[i]);
if multiplicity_[i] != 0 {
Some(1)
} else {
None
}
})
.filter_map(|x| x)
.sum::<usize>();

log::info!(
"{}: ··· Creating Binary Extension Table instance [{} / {} rows used {:.2}%]",
Self::MY_NAME,
self.num_rows,
non_zero_multiplicities,
num_rows,
non_zero_multiplicities as f64 / num_rows as f64 * 100.0
);

let air_instance = AirInstance::new(
Expand Down
13 changes: 9 additions & 4 deletions state-machines/mem/src/mem_align_rom_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::{
},
};

use log::info;
use p3_field::PrimeField;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::AirInstance;
Expand Down Expand Up @@ -43,7 +42,7 @@ pub enum ExtensionTableSMErr {
}

impl<F: PrimeField> MemAlignRomSM<F> {
const MY_NAME: &'static str = "MemAlignRom";
const MY_NAME: &'static str = "MemAlROM";

pub fn new(wcm: Arc<WitnessManager<F>>) -> Arc<Self> {
let pctx = wcm.get_pctx();
Expand Down Expand Up @@ -196,9 +195,15 @@ impl<F: PrimeField> MemAlignRomSM<F> {
trace_buffer[*row_idx as usize] =
MemAlignRomRow { multiplicity: F::from_canonical_u64(*multiplicity) };
}
}

info!("{}: ··· Creating Mem Align Rom instance", Self::MY_NAME,);
log::info!(
"{}: ··· Creating Mem Align ROM instance [{} / {} rows executed {:.2}%]",
Self::MY_NAME,
multiplicity.len(),
air_mem_align_rom_rows,
multiplicity.len() as f64 / air_mem_align_rom_rows as f64 * 100.0
);
}

let air_instance = AirInstance::new(
sctx,
Expand Down
14 changes: 12 additions & 2 deletions state-machines/mem/src/mem_align_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,7 @@ impl<F: PrimeField> MemAlignSM<F> {

pub fn prove(&self, computed_rows: &[MemAlignRow<F>]) {
if let Ok(mut rows) = self.rows.lock() {
let previous_num_rows = rows.len();
rows.extend_from_slice(computed_rows);

#[cfg(feature = "debug_mem_align")]
Expand All @@ -924,8 +925,17 @@ impl<F: PrimeField> MemAlignSM<F> {
let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]);

while rows.len() >= air_mem_align.num_rows() {
let num_drained = std::cmp::min(air_mem_align.num_rows(), rows.len());
let drained_rows = rows.drain(..num_drained).collect::<Vec<_>>();
// Find the correct cutting point
let cutting_point =
if previous_num_rows + computed_rows.len() == air_mem_align.num_rows() {
air_mem_align.num_rows()
} else {
// This is the case where previous_num_rows + computed_rows.len() >
// air_mem_align.num_rows() In this case, we prove
// computed_rows in the next air instance
previous_num_rows
};
let drained_rows = rows.drain(..cutting_point).collect::<Vec<_>>();

self.fill_new_air_instance(&drained_rows);
}
Expand Down
13 changes: 8 additions & 5 deletions state-machines/mem/src/mem_proxy_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ use crate::{
MAX_MAIN_STEP, MAX_MEM_ADDR, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_STEP, MAX_MEM_STEP_OFFSET,
MEMORY_MAX_DIFF, MEM_ADDR_MASK, MEM_BYTES, MEM_BYTES_BITS,
};
use log::info;

use p3_field::PrimeField;
use proofman_util::{timer_start_debug, timer_stop_and_log_debug};
Expand All @@ -104,7 +103,7 @@ macro_rules! debug_info {
($prefix:expr, $($arg:tt)*) => {
#[cfg(feature = "debug_mem_proxy_engine")]
{
info!(concat!("MemProxy: ",$prefix), $($arg)*);
log::info!(concat!("MemProxy: ",$prefix), $($arg)*);
}
};
}
Expand Down Expand Up @@ -133,6 +132,7 @@ pub struct MemProxyEngine<F: PrimeField> {
mem_align_sm: Arc<MemAlignSM<F>>,
next_open_addr: u32,
next_open_step: u64,
last_value: u64,
last_addr: u32,
last_step: u64,
intermediate_cases: u32,
Expand All @@ -156,6 +156,7 @@ impl<F: PrimeField> MemProxyEngine<F> {
mem_align_sm,
next_open_addr: NO_OPEN_ADDR,
next_open_step: NO_OPEN_STEP,
last_value: 0,
last_addr: 0xFFFF_FFFF,
last_step: 0,
intermediate_cases: 0,
Expand Down Expand Up @@ -370,11 +371,12 @@ impl<F: PrimeField> MemProxyEngine<F> {

// check if step difference is too large
if self.last_addr == w_addr && (step - self.last_step) > MEMORY_MAX_DIFF {
self.push_intermediate_internal_reads(w_addr, value, self.last_step, step);
self.push_intermediate_internal_reads(w_addr, self.last_value, self.last_step, step);
}

self.last_step = step;
self.last_addr = w_addr;
self.last_value = value;

let mem_op = MemInput { step, is_write, is_internal: false, addr: w_addr, value };
debug_info!(
Expand Down Expand Up @@ -542,9 +544,10 @@ impl<F: PrimeField> MemProxyEngine<F> {
);
module.send_inputs(&self.modules_data[module_id].inputs);
}
info!(
debug_info!(
"MemProxy: ··· Intermediate reads [cases:{} steps:{}]",
self.intermediate_cases, self.intermediate_steps
self.intermediate_cases,
self.intermediate_steps
);
}
/// Fetches the address map, defining and calculating all necessary structures to manage the
Expand Down
2 changes: 0 additions & 2 deletions state-machines/mem/src/mem_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ pub struct MemPreviousSegment {
#[allow(unused, unused_variables)]
impl<F: PrimeField> MemSM<F> {
pub fn new(wcm: Arc<WitnessManager<F>>, std: Arc<Std<F>>) -> Arc<Self> {
let pctx = wcm.get_pctx();
let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]);
let mem_sm =
Self { wcm: wcm.clone(), std: std.clone(), registered_predecessors: AtomicU32::new(0) };
let mem_sm = Arc::new(mem_sm);
Expand Down
16 changes: 13 additions & 3 deletions state-machines/rom/src/rom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub struct RomSM<F> {
}

impl<F: Field> RomSM<F> {
const MY_NAME: &'static str = "ROM ";

pub fn new(wcm: Arc<WitnessManager<F>>) -> Arc<Self> {
let rom_sm = Self { wcm: wcm.clone() };
let rom_sm = Arc::new(rom_sm);
Expand All @@ -42,9 +44,9 @@ impl<F: Field> RomSM<F> {

// Create an empty ROM trace
let pilout = Pilout::pilout();
let num_rows = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows();
let rom_trace_len = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows();

let mut rom_trace = RomTrace::new(num_rows);
let mut rom_trace = RomTrace::new(rom_trace_len);

// For every instruction in the rom, fill its corresponding ROM trace
let main_trace_len = pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows() as u64;
Expand Down Expand Up @@ -73,10 +75,18 @@ impl<F: Field> RomSM<F> {
}

// Padd with zeroes
for i in rom.insts.len()..num_rows {
for i in rom.insts.len()..rom_trace_len {
rom_trace[i] = RomRow::default();
}

log::info!(
"{}: ··· Creating ROM instance [{} / {} rows executed {:.2}%]",
Self::MY_NAME,
pc_histogram.map.len(),
rom_trace_len,
pc_histogram.map.len() as f64 / rom_trace_len as f64 * 100.0
);

let mut air_instance = AirInstance::new(
sctx.clone(),
ZISK_AIRGROUP_ID,
Expand Down
Loading