Skip to content

Commit

Permalink
Port arith SMs to use ZiskOp instead of hardcoded opcodes (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
fractasy authored Nov 21, 2024
1 parent ca620d3 commit c0c45a5
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 83 deletions.
14 changes: 0 additions & 14 deletions state-machines/arith/src/arith_constants.rs

This file was deleted.

7 changes: 3 additions & 4 deletions state-machines/arith/src/arith_full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ use std::sync::{
};

use crate::{
arith_constants::*, ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithTableInputs,
ArithTableSM,
ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithTableInputs, ArithTableSM,
};
use log::info;
use p3_field::Field;
use proofman::{WitnessComponent, WitnessManager};
use proofman_util::{timer_start_trace, timer_stop_and_log_trace};
use sm_common::i64_to_u64_field;
use zisk_core::ZiskRequiredOperation;
use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation};
use zisk_pil::*;

pub struct ArithFullSM<F> {
Expand Down Expand Up @@ -197,7 +196,7 @@ impl<F: Field> ArithFullSM<F> {

if padding_rows > 0 {
let mut t: ArithRow<F> = Default::default();
let padding_opcode = MULUH;
let padding_opcode = ZiskOp::Muluh.code();
t.op = F::from_canonical_u8(padding_opcode);
t.fab = F::one();
for i in padding_offset..num_rows {
Expand Down
90 changes: 52 additions & 38 deletions state-machines/arith/src/arith_operation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{arith_constants::*, arith_range_table_helpers::*};
use zisk_core::zisk_ops::ZiskOp;

use crate::arith_range_table_helpers::*;
use std::fmt;

pub struct ArithOperation {
Expand Down Expand Up @@ -132,19 +134,21 @@ impl ArithOperation {
self.input_a = input_a;
self.input_b = input_b;
self.div_by_zero = input_b == 0 &&
(op == DIV ||
op == REM ||
op == DIV_W ||
op == REM_W ||
op == DIVU ||
op == REMU ||
op == DIVU_W ||
op == REMU_W);

self.div_overflow = ((op == DIV || op == REM) &&
(op == ZiskOp::Div.code() ||
op == ZiskOp::Rem.code() ||
op == ZiskOp::DivW.code() ||
op == ZiskOp::RemW.code() ||
op == ZiskOp::Divu.code() ||
op == ZiskOp::Remu.code() ||
op == ZiskOp::DivuW.code() ||
op == ZiskOp::RemuW.code());

self.div_overflow = ((op == ZiskOp::Div.code() || op == ZiskOp::Rem.code()) &&
input_a == 0x8000_0000_0000_0000 &&
input_b == 0xFFFF_FFFF_FFFF_FFFF) ||
((op == DIV_W || op == REM_W) && input_a == 0x8000_0000 && input_b == 0xFFFF_FFFF);
((op == ZiskOp::DivW.code() || op == ZiskOp::RemW.code()) &&
input_a == 0x8000_0000 &&
input_b == 0xFFFF_FFFF);

let [a, b, c, d] = Self::calculate_abcd_from_ab(op, input_a, input_b);
self.a = Self::u64_to_chunks(a);
Expand Down Expand Up @@ -298,26 +302,35 @@ impl ArithOperation {
}

fn calculate_abcd_from_ab(op: u8, a: u64, b: u64) -> [u64; 4] {
match op {
MULU | MULUH => {
let zisk_op = ZiskOp::try_from_code(op).unwrap();
match zisk_op {
ZiskOp::Mulu | ZiskOp::Muluh => {
let c: u128 = a as u128 * b as u128;
[a, b, c as u64, (c >> 64) as u64]
}
MULSUH => {
ZiskOp::Mulsuh => {
let [c, d] = Self::calculate_mulsu(a, b);
[a, b, c, d]
}
MUL | MULH => {
ZiskOp::Mul | ZiskOp::Mulh => {
let [c, d] = Self::calculate_mul(a, b);
[a, b, c, d]
}
MUL_W => [a, b, Self::calculate_mul_w(a, b), 0],
DIVU | REMU => [Self::calculate_divu(a, b), b, a, Self::calculate_remu(a, b)],
DIVU_W | REMU_W => [Self::calculate_divu_w(a, b), b, a, Self::calculate_remu_w(a, b)],
DIV | REM => [Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)],
DIV_W | REM_W => [Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)],
ZiskOp::MulW => [a, b, Self::calculate_mul_w(a, b), 0],
ZiskOp::Divu | ZiskOp::Remu => {
[Self::calculate_divu(a, b), b, a, Self::calculate_remu(a, b)]
}
ZiskOp::DivuW | ZiskOp::RemuW => {
[Self::calculate_divu_w(a, b), b, a, Self::calculate_remu_w(a, b)]
}
ZiskOp::Div | ZiskOp::Rem => {
[Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)]
}
ZiskOp::DivW | ZiskOp::RemW => {
[Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)]
}
_ => {
panic!("Invalid opcode");
panic!("ArithOperation::calculate_abcd_from_ab() Invalid opcode={}", op);
}
}
}
Expand Down Expand Up @@ -351,65 +364,66 @@ impl ArithOperation {
let mut sb = false;
let mut rem = false;

match op {
MULU => {
let zisk_op = ZiskOp::try_from_code(op).unwrap();
match zisk_op {
ZiskOp::Mulu => {
self.main_mul = true;
}
MULUH => {}
MULSUH => {
ZiskOp::Muluh => {}
ZiskOp::Mulsuh => {
sa = true;
}
MUL => {
ZiskOp::Mul => {
sa = true;
sb = true;
self.main_mul = true;
}
MULH => {
ZiskOp::Mulh => {
sa = true;
sb = true;
}
MUL_W => {
ZiskOp::MulW => {
self.m32 = true;
self.sext = ((a * b) & 0xFFFF_FFFF) & 0x8000_0000 != 0;
self.main_mul = true;
}
DIVU => {
ZiskOp::Divu => {
self.div = true;
self.main_div = true;
}
REMU => {
ZiskOp::Remu => {
self.div = true;
rem = true;
}
DIV => {
ZiskOp::Div => {
sa = true;
sb = true;
self.div = true;
self.main_div = true;
}
REM => {
ZiskOp::Rem => {
sa = true;
sb = true;
rem = true;
self.div = true;
}
DIVU_W => {
ZiskOp::DivuW => {
// divu_w, remu_w
self.div = true;
self.m32 = true;
// use a in bus
self.sext = (a & 0x8000_0000) != 0;
self.main_div = true;
}
REMU_W => {
ZiskOp::RemuW => {
// divu_w, remu_w
self.div = true;
self.m32 = true;
rem = true;
// use d in bus
self.sext = (d & 0x8000_0000) != 0;
}
DIV_W => {
ZiskOp::DivW => {
// div_w, rem_w
sa = true;
sb = true;
Expand All @@ -419,7 +433,7 @@ impl ArithOperation {
self.sext = (a & 0x8000_0000) != 0;
self.main_div = true;
}
REM_W => {
ZiskOp::RemW => {
// div_w, rem_w
sa = true;
sb = true;
Expand All @@ -430,7 +444,7 @@ impl ArithOperation {
self.sext = (d & 0x8000_0000) != 0;
}
_ => {
panic!("Invalid opcode");
panic!("ArithOperation::update_flags_and_ranges() Invalid opcode={}", op);
}
}
self.signed = sa || sb;
Expand Down
72 changes: 47 additions & 25 deletions state-machines/arith/src/arith_operation_test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use zisk_core::zisk_ops::*;

use crate::{
arith_constants::*, arith_table_data, ArithOperation, ArithRangeTableHelpers, ArithTableHelpers,
};
use crate::{arith_table_data, ArithOperation, ArithRangeTableHelpers, ArithTableHelpers};

const MIN_N_64: u64 = 0x8000_0000_0000_0000;
const MIN_N_32: u64 = 0x0000_0000_8000_0000;
Expand Down Expand Up @@ -30,8 +28,22 @@ const ALL_VALUES: [u64; 16] = [
MAX_64,
];

const ALL_OPERATIONS: [u8; 14] =
[MUL, MULH, MULSUH, MULU, MULUH, DIVU, REMU, DIV, REM, MUL_W, DIVU_W, REMU_W, DIV_W, REM_W];
const ALL_OPERATIONS: [u8; 14] = [
ZiskOp::Mul.code(),
ZiskOp::Mulh.code(),
ZiskOp::Mulsuh.code(),
ZiskOp::Mulu.code(),
ZiskOp::Muluh.code(),
ZiskOp::Divu.code(),
ZiskOp::Remu.code(),
ZiskOp::Div.code(),
ZiskOp::Rem.code(),
ZiskOp::MulW.code(),
ZiskOp::DivuW.code(),
ZiskOp::RemuW.code(),
ZiskOp::DivW.code(),
ZiskOp::RemW.code(),
];

struct ArithOperationTest {
count: u32,
Expand Down Expand Up @@ -87,30 +99,40 @@ impl ArithOperationTest {
}

fn is_m32_op(op: u8) -> bool {
match op {
MUL | MULH | MULSUH | MULU | MULUH | DIVU | REMU | DIV | REM => false,
MUL_W | DIVU_W | REMU_W | DIV_W | REM_W => true,
_ => panic!("Invalid opcode"),
let zisk_op = ZiskOp::try_from_code(op).unwrap();
match zisk_op {
ZiskOp::Mul |
ZiskOp::Mulh |
ZiskOp::Mulsuh |
ZiskOp::Mulu |
ZiskOp::Muluh |
ZiskOp::Divu |
ZiskOp::Remu |
ZiskOp::Div |
ZiskOp::Rem => false,
ZiskOp::MulW | ZiskOp::DivuW | ZiskOp::RemuW | ZiskOp::DivW | ZiskOp::RemW => true,
_ => panic!("ArithOperationTest::is_m32_op() Invalid opcode={}", op),
}
}
fn calculate_emulator_res(op: u8, a: u64, b: u64) -> (u64, bool) {
match op {
MULU => op_mulu(a, b),
MULUH => op_muluh(a, b),
MULSUH => op_mulsuh(a, b),
MUL => op_mul(a, b),
MULH => op_mulh(a, b),
MUL_W => op_mul_w(a, b),
DIVU => op_divu(a, b),
REMU => op_remu(a, b),
DIVU_W => op_divu_w(a, b),
REMU_W => op_remu_w(a, b),
DIV => op_div(a, b),
REM => op_rem(a, b),
DIV_W => op_div_w(a, b),
REM_W => op_rem_w(a, b),
let zisk_op = ZiskOp::try_from_code(op).unwrap();
match zisk_op {
ZiskOp::Mulu => op_mulu(a, b),
ZiskOp::Muluh => op_muluh(a, b),
ZiskOp::Mulsuh => op_mulsuh(a, b),
ZiskOp::Mul => op_mul(a, b),
ZiskOp::Mulh => op_mulh(a, b),
ZiskOp::MulW => op_mul_w(a, b),
ZiskOp::Divu => op_divu(a, b),
ZiskOp::Remu => op_remu(a, b),
ZiskOp::DivuW => op_divu_w(a, b),
ZiskOp::RemuW => op_remu_w(a, b),
ZiskOp::Div => op_div(a, b),
ZiskOp::Rem => op_rem(a, b),
ZiskOp::DivW => op_div_w(a, b),
ZiskOp::RemW => op_rem_w(a, b),
_ => {
panic!("Invalid opcode");
panic!("ArithOperationTest::calculate_emulator_res() Invalid opcode={}", op);
}
}
}
Expand Down
2 changes: 0 additions & 2 deletions state-machines/arith/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
mod arith;
mod arith_constants;
mod arith_full;
mod arith_operation;
mod arith_range_table;
Expand All @@ -12,7 +11,6 @@ mod arith_table_helpers;
mod arith_operation_test;

pub use arith::*;
pub use arith_constants::*;
pub use arith_full::*;
pub use arith_operation::*;
pub use arith_range_table::*;
Expand Down

0 comments on commit c0c45a5

Please sign in to comment.