From c0c45a5a47a0f2e6dd568faea49c804a7c5fb7c9 Mon Sep 17 00:00:00 2001 From: fractasy <89866610+fractasy@users.noreply.github.com> Date: Thu, 21 Nov 2024 09:36:42 +0100 Subject: [PATCH] Port arith SMs to use ZiskOp instead of hardcoded opcodes (#170) --- state-machines/arith/src/arith_constants.rs | 14 --- state-machines/arith/src/arith_full.rs | 7 +- state-machines/arith/src/arith_operation.rs | 90 +++++++++++-------- .../arith/src/arith_operation_test.rs | 72 +++++++++------ state-machines/arith/src/lib.rs | 2 - 5 files changed, 102 insertions(+), 83 deletions(-) delete mode 100644 state-machines/arith/src/arith_constants.rs diff --git a/state-machines/arith/src/arith_constants.rs b/state-machines/arith/src/arith_constants.rs deleted file mode 100644 index 4a7af91a..00000000 --- a/state-machines/arith/src/arith_constants.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub const MULU: u8 = 0xb0; -pub const MULUH: u8 = 0xb1; -pub const MULSUH: u8 = 0xb3; -pub const MUL: u8 = 0xb4; -pub const MULH: u8 = 0xb5; -pub const MUL_W: u8 = 0xb6; -pub const DIVU: u8 = 0xb8; -pub const REMU: u8 = 0xb9; -pub const DIV: u8 = 0xba; -pub const REM: u8 = 0xbb; -pub const DIVU_W: u8 = 0xbc; -pub const REMU_W: u8 = 0xbd; -pub const DIV_W: u8 = 0xbe; -pub const REM_W: u8 = 0xbf; diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 96590ba3..3c83a887 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -4,15 +4,14 @@ use std::sync::{ }; use crate::{ - arith_constants::*, ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithTableInputs, - ArithTableSM, + ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithTableInputs, ArithTableSM, }; use log::info; use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; use sm_common::i64_to_u64_field; -use zisk_core::ZiskRequiredOperation; +use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::*; pub struct ArithFullSM { @@ -197,7 +196,7 @@ impl ArithFullSM { if padding_rows > 0 { let mut t: ArithRow = Default::default(); - let padding_opcode = MULUH; + let padding_opcode = ZiskOp::Muluh.code(); t.op = F::from_canonical_u8(padding_opcode); t.fab = F::one(); for i in padding_offset..num_rows { diff --git a/state-machines/arith/src/arith_operation.rs b/state-machines/arith/src/arith_operation.rs index a4eecdf2..c6b8b48a 100644 --- a/state-machines/arith/src/arith_operation.rs +++ b/state-machines/arith/src/arith_operation.rs @@ -1,4 +1,6 @@ -use crate::{arith_constants::*, arith_range_table_helpers::*}; +use zisk_core::zisk_ops::ZiskOp; + +use crate::arith_range_table_helpers::*; use std::fmt; pub struct ArithOperation { @@ -132,19 +134,21 @@ impl ArithOperation { self.input_a = input_a; self.input_b = input_b; self.div_by_zero = input_b == 0 && - (op == DIV || - op == REM || - op == DIV_W || - op == REM_W || - op == DIVU || - op == REMU || - op == DIVU_W || - op == REMU_W); - - self.div_overflow = ((op == DIV || op == REM) && + (op == ZiskOp::Div.code() || + op == ZiskOp::Rem.code() || + op == ZiskOp::DivW.code() || + op == ZiskOp::RemW.code() || + op == ZiskOp::Divu.code() || + op == ZiskOp::Remu.code() || + op == ZiskOp::DivuW.code() || + op == ZiskOp::RemuW.code()); + + self.div_overflow = ((op == ZiskOp::Div.code() || op == ZiskOp::Rem.code()) && input_a == 0x8000_0000_0000_0000 && input_b == 0xFFFF_FFFF_FFFF_FFFF) || - ((op == DIV_W || op == REM_W) && input_a == 0x8000_0000 && input_b == 0xFFFF_FFFF); + ((op == ZiskOp::DivW.code() || op == ZiskOp::RemW.code()) && + input_a == 0x8000_0000 && + input_b == 0xFFFF_FFFF); let [a, b, c, d] = Self::calculate_abcd_from_ab(op, input_a, input_b); self.a = Self::u64_to_chunks(a); @@ -298,26 +302,35 @@ impl ArithOperation { } fn calculate_abcd_from_ab(op: u8, a: u64, b: u64) -> [u64; 4] { - match op { - MULU | MULUH => { + let zisk_op = ZiskOp::try_from_code(op).unwrap(); + match zisk_op { + ZiskOp::Mulu | ZiskOp::Muluh => { let c: u128 = a as u128 * b as u128; [a, b, c as u64, (c >> 64) as u64] } - MULSUH => { + ZiskOp::Mulsuh => { let [c, d] = Self::calculate_mulsu(a, b); [a, b, c, d] } - MUL | MULH => { + ZiskOp::Mul | ZiskOp::Mulh => { let [c, d] = Self::calculate_mul(a, b); [a, b, c, d] } - MUL_W => [a, b, Self::calculate_mul_w(a, b), 0], - DIVU | REMU => [Self::calculate_divu(a, b), b, a, Self::calculate_remu(a, b)], - DIVU_W | REMU_W => [Self::calculate_divu_w(a, b), b, a, Self::calculate_remu_w(a, b)], - DIV | REM => [Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)], - DIV_W | REM_W => [Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)], + ZiskOp::MulW => [a, b, Self::calculate_mul_w(a, b), 0], + ZiskOp::Divu | ZiskOp::Remu => { + [Self::calculate_divu(a, b), b, a, Self::calculate_remu(a, b)] + } + ZiskOp::DivuW | ZiskOp::RemuW => { + [Self::calculate_divu_w(a, b), b, a, Self::calculate_remu_w(a, b)] + } + ZiskOp::Div | ZiskOp::Rem => { + [Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)] + } + ZiskOp::DivW | ZiskOp::RemW => { + [Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)] + } _ => { - panic!("Invalid opcode"); + panic!("ArithOperation::calculate_abcd_from_ab() Invalid opcode={}", op); } } } @@ -351,49 +364,50 @@ impl ArithOperation { let mut sb = false; let mut rem = false; - match op { - MULU => { + let zisk_op = ZiskOp::try_from_code(op).unwrap(); + match zisk_op { + ZiskOp::Mulu => { self.main_mul = true; } - MULUH => {} - MULSUH => { + ZiskOp::Muluh => {} + ZiskOp::Mulsuh => { sa = true; } - MUL => { + ZiskOp::Mul => { sa = true; sb = true; self.main_mul = true; } - MULH => { + ZiskOp::Mulh => { sa = true; sb = true; } - MUL_W => { + ZiskOp::MulW => { self.m32 = true; self.sext = ((a * b) & 0xFFFF_FFFF) & 0x8000_0000 != 0; self.main_mul = true; } - DIVU => { + ZiskOp::Divu => { self.div = true; self.main_div = true; } - REMU => { + ZiskOp::Remu => { self.div = true; rem = true; } - DIV => { + ZiskOp::Div => { sa = true; sb = true; self.div = true; self.main_div = true; } - REM => { + ZiskOp::Rem => { sa = true; sb = true; rem = true; self.div = true; } - DIVU_W => { + ZiskOp::DivuW => { // divu_w, remu_w self.div = true; self.m32 = true; @@ -401,7 +415,7 @@ impl ArithOperation { self.sext = (a & 0x8000_0000) != 0; self.main_div = true; } - REMU_W => { + ZiskOp::RemuW => { // divu_w, remu_w self.div = true; self.m32 = true; @@ -409,7 +423,7 @@ impl ArithOperation { // use d in bus self.sext = (d & 0x8000_0000) != 0; } - DIV_W => { + ZiskOp::DivW => { // div_w, rem_w sa = true; sb = true; @@ -419,7 +433,7 @@ impl ArithOperation { self.sext = (a & 0x8000_0000) != 0; self.main_div = true; } - REM_W => { + ZiskOp::RemW => { // div_w, rem_w sa = true; sb = true; @@ -430,7 +444,7 @@ impl ArithOperation { self.sext = (d & 0x8000_0000) != 0; } _ => { - panic!("Invalid opcode"); + panic!("ArithOperation::update_flags_and_ranges() Invalid opcode={}", op); } } self.signed = sa || sb; diff --git a/state-machines/arith/src/arith_operation_test.rs b/state-machines/arith/src/arith_operation_test.rs index 75f1aa4a..37e0a4c7 100644 --- a/state-machines/arith/src/arith_operation_test.rs +++ b/state-machines/arith/src/arith_operation_test.rs @@ -1,8 +1,6 @@ use zisk_core::zisk_ops::*; -use crate::{ - arith_constants::*, arith_table_data, ArithOperation, ArithRangeTableHelpers, ArithTableHelpers, -}; +use crate::{arith_table_data, ArithOperation, ArithRangeTableHelpers, ArithTableHelpers}; const MIN_N_64: u64 = 0x8000_0000_0000_0000; const MIN_N_32: u64 = 0x0000_0000_8000_0000; @@ -30,8 +28,22 @@ const ALL_VALUES: [u64; 16] = [ MAX_64, ]; -const ALL_OPERATIONS: [u8; 14] = - [MUL, MULH, MULSUH, MULU, MULUH, DIVU, REMU, DIV, REM, MUL_W, DIVU_W, REMU_W, DIV_W, REM_W]; +const ALL_OPERATIONS: [u8; 14] = [ + ZiskOp::Mul.code(), + ZiskOp::Mulh.code(), + ZiskOp::Mulsuh.code(), + ZiskOp::Mulu.code(), + ZiskOp::Muluh.code(), + ZiskOp::Divu.code(), + ZiskOp::Remu.code(), + ZiskOp::Div.code(), + ZiskOp::Rem.code(), + ZiskOp::MulW.code(), + ZiskOp::DivuW.code(), + ZiskOp::RemuW.code(), + ZiskOp::DivW.code(), + ZiskOp::RemW.code(), +]; struct ArithOperationTest { count: u32, @@ -87,30 +99,40 @@ impl ArithOperationTest { } fn is_m32_op(op: u8) -> bool { - match op { - MUL | MULH | MULSUH | MULU | MULUH | DIVU | REMU | DIV | REM => false, - MUL_W | DIVU_W | REMU_W | DIV_W | REM_W => true, - _ => panic!("Invalid opcode"), + let zisk_op = ZiskOp::try_from_code(op).unwrap(); + match zisk_op { + ZiskOp::Mul | + ZiskOp::Mulh | + ZiskOp::Mulsuh | + ZiskOp::Mulu | + ZiskOp::Muluh | + ZiskOp::Divu | + ZiskOp::Remu | + ZiskOp::Div | + ZiskOp::Rem => false, + ZiskOp::MulW | ZiskOp::DivuW | ZiskOp::RemuW | ZiskOp::DivW | ZiskOp::RemW => true, + _ => panic!("ArithOperationTest::is_m32_op() Invalid opcode={}", op), } } fn calculate_emulator_res(op: u8, a: u64, b: u64) -> (u64, bool) { - match op { - MULU => op_mulu(a, b), - MULUH => op_muluh(a, b), - MULSUH => op_mulsuh(a, b), - MUL => op_mul(a, b), - MULH => op_mulh(a, b), - MUL_W => op_mul_w(a, b), - DIVU => op_divu(a, b), - REMU => op_remu(a, b), - DIVU_W => op_divu_w(a, b), - REMU_W => op_remu_w(a, b), - DIV => op_div(a, b), - REM => op_rem(a, b), - DIV_W => op_div_w(a, b), - REM_W => op_rem_w(a, b), + let zisk_op = ZiskOp::try_from_code(op).unwrap(); + match zisk_op { + ZiskOp::Mulu => op_mulu(a, b), + ZiskOp::Muluh => op_muluh(a, b), + ZiskOp::Mulsuh => op_mulsuh(a, b), + ZiskOp::Mul => op_mul(a, b), + ZiskOp::Mulh => op_mulh(a, b), + ZiskOp::MulW => op_mul_w(a, b), + ZiskOp::Divu => op_divu(a, b), + ZiskOp::Remu => op_remu(a, b), + ZiskOp::DivuW => op_divu_w(a, b), + ZiskOp::RemuW => op_remu_w(a, b), + ZiskOp::Div => op_div(a, b), + ZiskOp::Rem => op_rem(a, b), + ZiskOp::DivW => op_div_w(a, b), + ZiskOp::RemW => op_rem_w(a, b), _ => { - panic!("Invalid opcode"); + panic!("ArithOperationTest::calculate_emulator_res() Invalid opcode={}", op); } } } diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index d7250835..714731cd 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -1,5 +1,4 @@ mod arith; -mod arith_constants; mod arith_full; mod arith_operation; mod arith_range_table; @@ -12,7 +11,6 @@ mod arith_table_helpers; mod arith_operation_test; pub use arith::*; -pub use arith_constants::*; pub use arith_full::*; pub use arith_operation::*; pub use arith_range_table::*;