diff --git a/.gitignore b/.gitignore index 52498727..63d76060 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ /build /proofs *.pilout -/tmp \ No newline at end of file +/tmp +*.log \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index af6eff06..2088eb94 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,40 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug", + "program": "target/debug/proofman-cli", + //"cargo": { + "args": [ + //"run", + //"--bin", + //"proofman-cli", + "verify-constraints", + "--witness-lib", + "../zisk/target/debug/libzisk_witness.so", + "--rom", + "../zisk/emulator/benches/data/my.elf", + "-i", + "../zisk/emulator/benches/data/input_two_segments.bin", + "--proving-key", + "../zisk/build/provingKey" + ], + //"filter": { + // "name": "proofman_cli", + // "kind": "lib" + //} + //}, + //"args": [], + "cwd": "${workspaceFolder}", + "environment": [ + { "name": "RUSTFLAGS", "value": "-L native=/home/zkronos73/devel/zisk2/pil2-proofman/pil2-stark/lib" } + ], + "sourceLanguages": [ + "rust" + ] + }, { "type": "lldb", "request": "launch", diff --git a/Cargo.lock b/Cargo.lock index 7e3bf0a4..326ad3e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,9 +156,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" [[package]] name = "byteorder" @@ -198,9 +198,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.36" +version = "1.2.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "baee610e9452a8f6f0a1b6194ec09ff9e2d85dea54432acdae41aa0761c95d70" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -248,9 +248,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", @@ -282,9 +282,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "colorchoice" @@ -311,7 +311,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", - "unicode-width", + "unicode-width 0.1.14", "windows-sys 0.52.0", ] @@ -506,9 +506,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "findshlibs" @@ -645,8 +645,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -750,9 +752,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", @@ -956,15 +958,15 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.17.8" +version = "0.17.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" dependencies = [ "console", - "instant", "number_prefix", "portable-atomic", - "unicode-width", + "unicode-width 0.2.0", + "web-time", ] [[package]] @@ -985,15 +987,6 @@ dependencies = [ "str_stack", ] -[[package]] -name = "instant" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" -dependencies = [ - "cfg-if", -] - [[package]] name = "ipnet" version = "2.10.1" @@ -1046,9 +1039,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "7a73e9fe3c49d7afb2ace819fa181a287ce54a0983eda4e0eb05c22f82ffe534" [[package]] name = "jobserver" @@ -1082,9 +1075,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libgit2-sys" @@ -1578,7 +1571,7 @@ dependencies = [ "smallvec", "symbolic-demangle", "tempfile", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -1749,9 +1742,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", @@ -1760,26 +1753,29 @@ dependencies = [ "rustc-hash", "rustls", "socket2", - "thiserror", + "thiserror 2.0.3", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", "ring", "rustc-hash", "rustls", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.3", "tinyvec", "tracing", + "web-time", ] [[package]] @@ -1872,7 +1868,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -1889,9 +1885,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -2010,9 +2006,9 @@ checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" 
[[package]] name = "rustix" -version = "0.38.39" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -2023,9 +2019,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" dependencies = [ "once_cell", "ring", @@ -2049,6 +2045,9 @@ name = "rustls-pki-types" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" @@ -2090,18 +2089,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.214" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", @@ -2110,9 +2109,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "indexmap", "itoa", @@ -2162,10 +2161,13 @@ name = "sm-arith" version = "0.1.0" dependencies = [ "log", + "num-bigint", "p3-field", + "pil-std-lib", "proofman", "proofman-common", "proofman-macros", + "proofman-util", "rayon", "sm-common", "zisk-core", @@ -2269,6 +2271,7 @@ dependencies = [ name = "sm-rom" version = "0.1.0" dependencies = [ + "itertools 0.13.0", "log", "p3-field", "proofman", @@ -2442,9 +2445,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", "fastrand", @@ -2455,18 +2458,38 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.68" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.3", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] name = "thiserror-impl" -version = "1.0.68" +version = "2.0.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" dependencies = [ "proc-macro2", "quote", @@ -2666,9 +2689,9 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-width" @@ -2676,6 +2699,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "untrusted" version = "0.9.0" @@ -2857,6 +2886,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.6" diff --git a/Cargo.toml b/Cargo.toml index 5c2ea497..b97f5cb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -#Local development +# Local 
development # proofman-common = { path = "../pil2-proofman/common" } # proofman-macros = { path = "../pil2-proofman/macros" } # proofman-util = { path = "../pil2-proofman/util" } diff --git a/book/getting_started/proof.md b/book/getting_started/proof.md new file mode 100644 index 00000000..38b672b2 --- /dev/null +++ b/book/getting_started/proof.md @@ -0,0 +1,76 @@ +## Steps to verify constraints or generate a proof + +Compile the pils: +``` +node ../pil2-compiler/src/pil.js pil/fork_0/pil/zisk.pil -I lib/std/pil -o pil/fork_0/pil/zisk.pilout +``` + +Generate the "structs" for the different airs: +`(cd ../pil2-proofman; cargo run --bin proofman-cli pil-helpers --pilout ../zisk/pil/fork_0/pil/zisk.pilout --path ../zisk/pil/fork_0/src/ -o)` + +Prepare the "fast tools" (only the first time): +`(cd ../zkevm-prover && git switch develop_rust_lib && git submodule init && git submodule update && make -j bctree && make starks_lib -j)` + +Run the setup for the pil; this step is necessary **only when the pil changes**: +`node ../pil2-proofman-js/src/main_setup.js -a pil/fork_0/pil/zisk.pilout -b build -t ../zkevm-prover/build/bctree` + +This step is optional and only needs to be done once. Edit the file pil2-proofman/provers/starks-lib-c/Cargo.toml to remove "no_lib_link" from line 12: +`nano ../pil2-proofman/provers/starks-lib-c/Cargo.toml` + +Compile the witness computation library (libzisk_witness.so). If nightly is not your default toolchain, add +nightly to the build command.
+`cargo build --release` + +In the following steps to verify constraints or generate a proof, select one of these inputs: +- input.bin: a large number of shas +- input_one_segment.bin: only one sha +- input_two_segments.bin: 512 shas + +To **verify constraints** use: +`(cd ../pil2-proofman; cargo run --release --bin proofman-cli verify-constraints --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input.bin --proving-key ../zisk/build/provingKey)` + +To **generate proof** use: +`(cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input.bin --proving-key ../zisk/build/provingKey)` + +## Steps to compile a verifiable rust program + +### Setup +Install qemu: +`sudo apt-get install qemu-system` +Add tokens to access repos: +``` +export GITHUB_ACCESS_TOKEN=.... +export ZISK_TOKEN=.... +``` +### Create new hello_world project +Create project with toolchain: +```bash +cargo-zisk sdk new hello_world +cd hello_world +``` + +Compile and execute in **riscv mode**: +`cargo-zisk run --release` + +Compile and execute in **zisk mode**: +`cargo-zisk run --release --sim` + +Execute with ziskemu: +`ziskemu -i build/input.bin -x -e target/riscv64ima-polygon-ziskos-elf/release/fibonacci` + +### Update toolchain +``` +ziskup +``` +If ziskup fails, you can update ziskemu manually. + +### Update ziskemu manually +```bash +cd zisk +git pull +cargo install --path emulator +cp ~/.cargo/bin/ziskemu ~/.zisk/bin/ +``` + +```bash +ziskemu -i build/input.bin -x -e target/riscv64ima-polygon-ziskos-elf/debug/fibonacci +``` diff --git a/core/src/zisk_ops.rs b/core/src/zisk_ops.rs index 89d9c02e..661b8c19 100644 --- a/core/src/zisk_ops.rs +++ b/core/src/zisk_ops.rs @@ -247,19 +247,19 @@ define_ops!
{ (Or, "or", Binary, 77, 0x21, opc_or, op_or), (Xor, "xor", Binary, 77, 0x22, opc_xor, op_xor), (Mulu, "mulu", ArithAm32, 97, 0xb0, opc_mulu, op_mulu), - (Mul, "mul", ArithAm32, 97, 0xb1, opc_mul, op_mul), - (MulW, "mul_w", ArithAm32, 44, 0xb5, opc_mul_w, op_mul_w), - (Muluh, "muluh", ArithAm32, 97, 0xb8, opc_muluh, op_muluh), - (Mulh, "mulh", ArithAm32, 97, 0xb9, opc_mulh, op_mulh), - (Mulsuh, "mulsuh", ArithAm32, 97, 0xbb, opc_mulsuh, op_mulsuh), - (Divu, "divu", ArithAm32, 174, 0xc0, opc_divu, op_divu), - (Div, "div", ArithAm32, 174, 0xc1, opc_div, op_div), - (DivuW, "divu_w", ArithA32, 136, 0xc4, opc_divu_w, op_divu_w), - (DivW, "div_w", ArithA32, 136, 0xc5, opc_div_w, op_div_w), - (Remu, "remu", ArithAm32, 174, 0xc8, opc_remu, op_remu), - (Rem, "rem", ArithAm32, 174, 0xc9, opc_rem, op_rem), - (RemuW, "remu_w", ArithA32, 136, 0xcc, opc_remu_w, op_remu_w), - (RemW, "rem_w", ArithA32, 136, 0xcd, opc_rem_w, op_rem_w), + (Muluh, "muluh", ArithAm32, 97, 0xb1, opc_muluh, op_muluh), + (Mulsuh, "mulsuh", ArithAm32, 97, 0xb3, opc_mulsuh, op_mulsuh), + (Mul, "mul", ArithAm32, 97, 0xb4, opc_mul, op_mul), + (Mulh, "mulh", ArithAm32, 97, 0xb5, opc_mulh, op_mulh), + (MulW, "mul_w", ArithAm32, 44, 0xb6, opc_mul_w, op_mul_w), + (Divu, "divu", ArithAm32, 174, 0xb8, opc_divu, op_divu), + (Remu, "remu", ArithAm32, 174, 0xb9, opc_remu, op_remu), + (Div, "div", ArithAm32, 174, 0xba, opc_div, op_div), + (Rem, "rem", ArithAm32, 174, 0xbb, opc_rem, op_rem), + (DivuW, "divu_w", ArithA32, 136, 0xbc, opc_divu_w, op_divu_w), + (RemuW, "remu_w", ArithA32, 136, 0xbd, opc_remu_w, op_remu_w), + (DivW, "div_w", ArithA32, 136, 0xbe, opc_div_w, op_div_w), + (RemW, "rem_w", ArithA32, 136, 0xbf, opc_rem_w, op_rem_w), (Minu, "minu", Binary, 77, 0x09, opc_minu, op_minu), (Min, "min", Binary, 77, 0x0a, opc_min, op_min), (MinuW, "minu_w", Binary, 77, 0x19, opc_minu_w, op_minu_w), diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index a56c35d0..76b11fab 100644 --- a/emulator/src/emu.rs +++ 
b/emulator/src/emu.rs @@ -893,7 +893,8 @@ impl<'a> Emu<'a> { m32: F::from_bool(inst.m32), addr1: F::from_canonical_u64(addr1), __debug_operation_bus_enabled: F::from_bool( - inst.op_type == ZiskOperationType::Binary || + inst.op_type == ZiskOperationType::Arith || + inst.op_type == ZiskOperationType::Binary || inst.op_type == ZiskOperationType::BinaryE, ), } diff --git a/pil/operations.pil b/pil/operations.pil new file mode 100644 index 00000000..d857a854 --- /dev/null +++ b/pil/operations.pil @@ -0,0 +1,50 @@ +const int OP_FLAG = 0x00; +const int OP_COPYB = 0x01; +const int OP_SIGNEXTEND_B = 0x02; +const int OP_SIGNEXTEND_H = 0x03; +const int OP_SIGNEXTEND_W = 0x04; +const int OP_ADD = 0x10; +const int OP_ADD_W = 0x14; +const int OP_SUB = 0x20; +const int OP_SUB_W = 0x24; +const int OP_SLL = 0x30; +const int OP_SLL_W = 0x34; +const int OP_SRA = 0x40; +const int OP_SRL = 0x41; +const int OP_SRA_W = 0x44; +const int OP_SRL_W = 0x45; +const int OP_EQ = 0x50; +const int OP_EQ_W = 0x54; +const int OP_LTU = 0x60; +const int OP_LT = 0x61; +const int OP_LTU_W = 0x64; +const int OP_LT_W = 0x65; +const int OP_LEU = 0x70; +const int OP_LE = 0x71; +const int OP_LEU_W = 0x74; +const int OP_LE_W = 0x75; +const int OP_AND = 0x80; +const int OP_OR = 0x90; +const int OP_XOR = 0xA0; +const int OP_MULU = 0xB0; +const int OP_MUL = 0xB1; +const int OP_MUL_W = 0xB5; +const int OP_MULUH = 0xB8; +const int OP_MULH = 0xB9; +const int OP_MULSUH = 0xBB; +const int OP_DIVU = 0xC0; +const int OP_DIV = 0xC1; +const int OP_DIVU_W = 0xC4; +const int OP_DIV_W = 0xC5; +const int OP_REMU = 0xC8; +const int OP_REM = 0xC9; +const int OP_REMU_W = 0xCC; +const int OP_REM_W = 0xCD; +const int OP_MINU = 0xD0; +const int OP_MIN = 0xD1; +const int OP_MINU_W = 0xD4; +const int OP_MIN_W = 0xD5; +const int OP_MAXU = 0xE0; +const int OP_MAX = 0xE1; +const int OP_MAXU_W = 0xE4; +const int OP_MAX_W = 0xE5; diff --git a/pil/src/lib.rs b/pil/src/lib.rs index 5d31b15a..aee8bab5 100644 --- a/pil/src/lib.rs +++ 
b/pil/src/lib.rs @@ -3,7 +3,6 @@ mod pil_helpers; pub use pil_helpers::*; //TODO To be removed when ready in ZISK_PIL -pub const ARITH_AIRGROUP_ID: usize = 101; pub const ARITH32_AIR_IDS: &[usize] = &[4, 5]; pub const ARITH64_AIR_IDS: &[usize] = &[6]; pub const ARITH3264_AIR_IDS: &[usize] = &[7]; diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 6a7829f7..9a796335 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -14,15 +14,21 @@ pub const MAIN_AIR_IDS: &[usize] = &[0]; pub const ROM_AIR_IDS: &[usize] = &[1]; -pub const BINARY_AIR_IDS: &[usize] = &[2]; +pub const ARITH_AIR_IDS: &[usize] = &[2]; -pub const BINARY_TABLE_AIR_IDS: &[usize] = &[3]; +pub const ARITH_TABLE_AIR_IDS: &[usize] = &[3]; -pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[4]; +pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[4]; -pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[5]; +pub const BINARY_AIR_IDS: &[usize] = &[5]; -pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[6]; +pub const BINARY_TABLE_AIR_IDS: &[usize] = &[6]; + +pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[7]; + +pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[8]; + +pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[9]; pub struct Pilout; @@ -34,6 +40,9 @@ impl Pilout { air_group.add_air(Some("Main"), 2097152); air_group.add_air(Some("Rom"), 1048576); + air_group.add_air(Some("Arith"), 2097152); + air_group.add_air(Some("ArithTable"), 128); + air_group.add_air(Some("ArithRangeTable"), 4194304); air_group.add_air(Some("Binary"), 2097152); air_group.add_air(Some("BinaryTable"), 4194304); air_group.add_air(Some("BinaryExtension"), 2097152); diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index aea9c202..32cfe09f 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -11,6 +11,18 @@ trace!(RomRow, RomTrace { line: F, a_offset_imm0: F, a_imm1: F, b_offset_imm0: F, b_imm1: F, ind_width: F, op: F, 
store_offset: F, jmp_offset1: F, jmp_offset2: F, flags: F, multiplicity: F, }); +trace!(ArithRow, ArithTrace { + carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, sext: F, m32: F, div: F, fab: F, na_fb: F, nb_fa: F, debug_main_step: F, main_div: F, main_mul: F, signed: F, div_by_zero: F, div_overflow: F, inv_sum_all_bs: F, op: F, bus_res1: F, multiplicity: F, range_ab: F, range_cd: F, +}); + +trace!(ArithTableRow, ArithTableTrace { + multiplicity: F, +}); + +trace!(ArithRangeTableRow, ArithRangeTableTrace { + multiplicity: F, +}); + trace!(BinaryRow, BinaryTrace { m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F, }); diff --git a/pil/zisk.pil b/pil/zisk.pil index 91feb604..0e97aeb6 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -6,6 +6,7 @@ require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" require "binary/pil/binary_extension.pil" require "binary/pil/binary_extension_table.pil" +require "arith/pil/arith.pil" // require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; @@ -13,6 +14,9 @@ airgroup Zisk { Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); Rom(N: 2**20); // Mem(N: 2**21, RC: 2); + Arith(N: 2**21, operation_bus_id: OPERATION_BUS_ID); + ArithTable(); + ArithRangeTable(); Binary(N: 2**21, operation_bus_id: OPERATION_BUS_ID); BinaryTable(disable_fixed: 0); BinaryExtension(N: 2**21, operation_bus_id: OPERATION_BUS_ID); diff --git a/state-machines/arith/Cargo.toml b/state-machines/arith/Cargo.toml index 23f8d0b2..7c0859b8 100644 --- a/state-machines/arith/Cargo.toml +++ b/state-machines/arith/Cargo.toml @@ -5,17 +5,22 @@ edition = "2021" [dependencies] zisk-core = { path = "../../core" } +zisk-pil = { path="../../pil" } sm-common = { path = "../common" } -zisk-pil = { path = "../../pil" } p3-field = { workspace=true } proofman-common = { workspace = true } proofman-macros = { 
workspace = true } +proofman-util = { workspace = true } proofman = { workspace = true } +pil-std-lib = { workspace = true } log = { workspace = true } rayon = { workspace = true } +num-bigint = { workspace = true } + [features] default = [] +generate_code_arith_range_table = [] no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] \ No newline at end of file diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index e69de29b..1a830cb5 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -0,0 +1,320 @@ +require "std_lookup.pil" +require "std_range_check.pil" +require "operations.pil" +require "arith_table.pil" +require "arith_range_table.pil" + +// full mul_64 full_32 mul_32 +// TOTAL 88 77 57 44 + +const int OP_LT_ABS = 0x9F; + +airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_result = 0) { + + const int CHUNK_SIZE = 2**16; + const int CHUNKS_INPUT = 4; + const int CHUNKS_OP = CHUNKS_INPUT * 2; + + col witness carry[CHUNKS_OP - 1]; + col witness a[CHUNKS_INPUT]; + col witness b[CHUNKS_INPUT]; + col witness c[CHUNKS_INPUT]; + col witness d[CHUNKS_INPUT]; + + col witness na; // a is negative + col witness nb; // b is negative + col witness nr; // rem is negative + col witness np; // prod is negative + col witness sext; // sign extend for 32 bits result + + col witness m32; // 32 bits operation + col witness div; // division operation (div,rem) + + col witness fab; // fab, to decrease degree of intermediate products a * b + // fab = 1 if sign of a,b are the same + // fab = -1 if sign of a,b are different + + col witness na_fb; + col witness nb_fa; + + col witness debug_main_step; // only for debug +/* + col witness secondary; // op_index: 0 => first result, 1 => second result; + secondary * (secondary - 1) === 0; +*/ + col witness main_div; + col witness main_mul; + col witness signed; + + col witness div_by_zero; + col witness div_overflow; + + main_div 
* (main_div - 1) === 0; + main_mul * (main_mul - 1) === 0; + main_mul * main_div === 0; + signed * (1 - signed) === 0; + div_by_zero * (1 - div_by_zero) === 0; + div_overflow * (1 - div_overflow) === 0; + + // factor ab ∈ {-1, 1} + fab === 1 - 2 * na - 2 * nb + 4 * na * nb; + na_fb === na * (1 - 2 * nb); + nb_fa === nb * (1 - 2 * na); + + expr sum_all_bs = 0; + for (int i = 0; i < length(b); ++i) { + div_by_zero * b[i] === 0; // forces b to be zero when div_by_zero + sum_all_bs = sum_all_bs + b[i]; // all b are values of 16 bits (verified by range_check) + } + + // when div_by_zero, a would otherwise be free; this forces a to be all ones (0xFFFF per chunk, upper chunks 0 when m32) + div_by_zero * (a[0] - 0xFFFF) === 0; + div_by_zero * (a[1] - 0xFFFF) === 0; + div_by_zero * (a[2] - (1 - m32) * 0xFFFF) === 0; + div_by_zero * (a[3] - (1 - m32) * 0xFFFF) === 0; + + // when div_overflow, this forces b to be all ones (0xFFFF per chunk, upper chunks 0 when m32) + div_overflow * (b[0] - 0xFFFF) === 0; + div_overflow * (b[1] - 0xFFFF) === 0; + div_overflow * (b[2] - (1 - m32) * 0xFFFF) === 0; + div_overflow * (b[3] - (1 - m32) * 0xFFFF) === 0; + + // when div_overflow, this forces c to be 0x8000_0000 (m32) or 0x8000_0000_0000_0000 (64 bits) + div_overflow * c[0] === 0; + div_overflow * (c[1] - m32 * 0x8000) === 0; + div_overflow * c[2] === 0; + div_overflow * (c[3] - (1 - m32) * 0x8000) === 0; + + // b != 0 <==> sum_all_bs != 0 + col witness inv_sum_all_bs; + + // div = 0 => div_by_zero must be 0 => 0 (no need to calculate the inverse) + // div = 1 and div_by_zero = 0 => 1 calculate the inverse to demonstrate b != 0 + // div = 1 and div_by_zero = 1 => 0 (no need to calculate the inverse) + (div - div_by_zero) * (1 - inv_sum_all_bs * sum_all_bs) === 0; + + // div_by_zero only active for divisions + div_by_zero * (1 - div) === 0; + + // div_overflow only active for signed divisions + div_overflow * (1 - div) === 0; + div_overflow * (1 - signed) === 0; + + div_overflow * div_by_zero === 0; + div_by_zero * div_overflow === 0; + + const expr eq[CHUNKS_OP]; + + // NOTE: Equations with m32 for
multiplication do not exist, because mul m32 is an unsigned operation. + // In the internal equations, it's the same as unsigned mul 64 where the high parts of a and b are zero + + // abs(x) x >= 0 ➜ nx == 0 ➜ x + // x < 0 ➜ nx == 1 ➜ 2^64 - x + // + // abs(x,nx) = nx * (2^64 - 2 * x) + x = 2^64 * nx - 2 * nx * x + x + // + // chunk[0] = x[0] - 2 * nx * x[0] // 2^0 + // chunk[1] = x[1] - 2 * nx * x[1] // 2^16 + // chunk[2] = x[2] - 2 * nx * x[2] // 2^32 + // chunk[3] = x[3] - 2 * nx * x[3] // 2^48 + // chunk[4] = nx // 2^64 + // + // or chunk[3] = x[3] - 2 * nx * x[3] + 2^16 * nx + // chunk[4] = 0 + // + // dual use of d, on multiplication d is high part of result, while in division d + // is the remainder. Selector of these two uses is div or nr (because nr = 0 for div = 0) + // + // div = 0 ➜ a * b = 2^64 * d + c ➜ a * b - 2^64 * d - c === 0 + // div = 1 ➜ a * b + d = c ➜ a * b - c + d === 0 + // + // eq = a * b - c + div * d - (1 - div) * 2^64 * d + + eq[0] = fab * a[0] * b[0] + - c[0] // ⎫ np == 0 ➜ - c + + 2 * np * c[0] // ⎭ np == 1 ➜ - c + 2c = c + + div * d[0] // ⎫ div == 0 ➜ nr = 0 ➜ 0 + - 2 * nr * d[0]; // ⎥ div == 1 and nr == 0 ➜ d + // ⎭ div == 1 and nr == 1 ➜ d - 2d = -d + + eq[1] = fab * a[1] * b[0] + + fab * a[0] * b[1] + - c[1] // ⎫ np == 0 ➜ - c + + 2 * np * c[1] // ⎭ np == 1 ➜ c + + div * d[1] // ⎫ div == 1 ➜ d or -d + - 2 * nr * d[1]; // ⎭ div == 0 ➜ 0 + + eq[2] = fab * a[2] * b[0] + + fab * a[1] * b[1] + + fab * a[0] * b[2] + + a[0] * nb_fa * m32 // ⎫ sign contribution when m32 + + b[0] * na_fb * m32 // ⎭ + - c[2] // ⎫ np == 0 ➜ - c + + 2 * np * c[2] // ⎭ np == 1 ➜ c + + div * d[2] // ⎫ div == 1 ➜ d or -d + - 2 * nr * d[2] // ⎭ div == 0 ➜ 0 + - np * div * m32 // m32 == 1 and np == 1 ➜ -2^32 (global) or -1 (in 3rd chunk) + + nr * m32; // m32 == 1 and nr == 1 ➜ div == 1 ➜ 2^32 (global) or 1 (in 3rd chunk) + + eq[3] = fab * a[3] * b[0] + + fab * a[2] * b[1] + + fab * a[1] * b[2] + + fab * a[0] * b[3] // NOTE: m32 => high part is 0 + + a[1] * nb_fa * m32 // ⎫
sign contribution when m32 + + b[1] * na_fb * m32 // ⎭ + - c[3] // ⎫ np == 0 ➜ - c + + 2 * np * c[3] // ⎭ np == 1 ➜ c + + div * d[3] // ⎫ div == 1 ➜ d or -d + - 2 * nr * d[3]; // ⎭ div == 0 ➜ 0 + + eq[4] = fab * a[3] * b[1] + + fab * a[2] * b[2] + + fab * a[1] * b[3] + + na * nb * m32 + // + b[0] * na * (1 - 2 * nb) + // + a[0] * nb * (1 - 2 * na) + + b[0] * na_fb * (1 - m32) + + a[0] * nb_fa * (1 - m32) + + - np * m32 * (1 - div) // + - np * (1 - m32) * div // 2^64 (np) + + nr * (1 - m32) // 2^64 (nr) + + - d[0] * (1 - div) // 3 degree + + 2 * np * d[0] * (1 - div); // 3 degree + + eq[5] = fab * a[3] * b[2] // 3 degree + + fab * a[2] * b[3] // 3 degree + + a[1] * nb_fa * (1 - m32) + + b[1] * na_fb * (1 - m32) + - d[1] * (1 - div) + + d[1] * 2 * np * (1 - div); + + eq[6] = fab * a[3] * b[3] // 3 degree + + a[2] * nb_fa * (1 - m32) + + b[2] * na_fb * (1 - m32) + - d[2] * (1 - div) + + 2 * np * d[2] * (1 - div); // 3 degree + + eq[7] = CHUNK_SIZE * na * nb * (1 - m32) + + a[3] * nb_fa * (1 - m32) + + b[3] * na_fb * (1 - m32) + - CHUNK_SIZE * np * (1 - div) * (1 - m32) // 3 degree + // - CHUNK_SIZE * np * (1 - div) + - d[3] * (1 - div) + + 2 * np * d[3] * (1 - div); // 3 degree + + eq[0] - carry[0] * CHUNK_SIZE === 0; + for (int index = 1; index < (CHUNKS_OP - 1); ++index) { + eq[index] + carry[index-1] - carry[index] * CHUNK_SIZE === 0; + } + eq[CHUNKS_OP-1] + carry[CHUNKS_OP-2] === 0; + + // binary contraint + div * (1 - div) === 0; + m32 * (1 - m32) === 0; + na * (1 - na) === 0; + nb * (1 - nb) === 0; + nr * (1 - nr) === 0; + np * (1 - np) === 0; + sext * (1 - sext) === 0; + + col witness op; + + // div m32 sa sb primary secondary opcodes na nb np nr sext(c) + // ------------------------------------------------------------------------------------- + // 0 0 0 0 mulu muluh 0xb0 176 0xb1 177 =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh 0xb2 - 0xb3 179 a3 =0 d3 =0 =0 =0 a3, d3 + // 0 0 1 1 mul mulh 0xb4 180 0xb5 181 a3 b3 d3 =0 =0 =0 a3,b3, d3 + // 0 1 0 0 mul_w *n/a* 
0xb6 182 0xb7 - =0 =0 =0 =0 c1 =0 + + // div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2) + // ------------------------------------------------------------------------------------------ + // 1 0 0 0 divu remu 0xb8 184 0xb9 185 =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem 0xba 186 0xbb 187 a3 b3 c3 d3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w 0xbc 188 0xbd 189 =0 =0 =0 =0 a1 d1 a1 ,d1 + // 1 1 1 1 div_w rem_w 0xbe 190 0xbf 191 a1 b1 c1 d1 a1 d1 a1,b1,c1,d1 + + // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb + // (*2) sext affects to 32 bits result (bus), but in divisions a is used as result + // see 5 previous constraints. + // =0 means forced to zero by previous constraints + + // bus result primary secondary + // ---------------------------------- + // mul (mulh) c d + // div (remu) a d + + const expr secondary = 1 - main_mul - main_div; + const expr bus_a0 = div * (c[0] + c[1] * CHUNK_SIZE) + (1 - div) * (a[0] + a[1] * CHUNK_SIZE); + const expr bus_a1 = div * (c[2] + c[3] * CHUNK_SIZE) + (1 - div) * (a[2] + a[3] * CHUNK_SIZE); + + const expr bus_b0 = b[0] + b[1] * CHUNK_SIZE; + const expr bus_b1 = b[2] + b[3] * CHUNK_SIZE; + + const expr bus_res0 = secondary * (d[0] + d[1] * CHUNK_SIZE) + + main_mul * (c[0] + c[1] * CHUNK_SIZE) + + main_div * (a[0] + a[1] * CHUNK_SIZE); + + const expr bus_res1_64 = (secondary * (d[2] + d[3] * CHUNK_SIZE) + + main_mul * (c[2] + c[3] * CHUNK_SIZE) + + main_div * (a[2] + a[3] * CHUNK_SIZE)); + col witness bus_res1; + + bus_res1 === sext * 0xFFFF_FFFF + (1 - m32) * bus_res1_64; + + m32 * bus_a1 === 0; + m32 * bus_b1 === 0; + + col witness multiplicity; + + lookup_proves(operation_bus_id, [debug_main_step, + op, + bus_a0, bus_a1, + bus_b0, bus_b1, + bus_res0, bus_res1, + div_by_zero /*+ div_overflow*/], mul: multiplicity); + + // TODO: remainder check + // lookup_assumes(operation_bus_id, [debug_main_step, signed * (OP_LT_ABS - OP_LT) + OP_LT, + // (d[0] + CHUNK_SIZE * d[1]), 
(d[2] + CHUNK_SIZE * d[3]) + m32 * nr * 0xFFFFFFFF, + // (b[0] + CHUNK_SIZE * b[1]), (b[2] + CHUNK_SIZE * b[3]) + m32 * nb * 0xFFFFFFFF, + // 1, 0, 1], sel: div); + + for (int index = 0; index < length(carry); ++index) { + arith_range_table_assumes(ARITH_RANGE_CARRY, carry[index]); // TODO: review carry range + } + + col witness range_ab; + col witness range_cd; + + arith_table_assumes(op, m32, div, na, nb, np, nr, sext, div_by_zero, div_overflow, main_mul, + main_div, signed, range_ab, range_cd); + + const expr range_a3 = range_ab; + const expr range_a1 = range_ab + 26; + const expr range_b3 = range_ab + 17; + const expr range_b1 = range_ab + 9; + + const expr range_c3 = range_cd; + const expr range_c1 = range_cd + 26; + const expr range_d3 = range_cd + 17; + const expr range_d1 = range_cd + 9; + + arith_range_table_assumes(range_a1, a[1]); + arith_range_table_assumes(range_b1, b[1]); + arith_range_table_assumes(range_c1, c[1]); + arith_range_table_assumes(range_d1, d[1]); + arith_range_table_assumes(range_a3, a[3]); + arith_range_table_assumes(range_b3, b[3]); + arith_range_table_assumes(range_c3, c[3]); + arith_range_table_assumes(range_d3, d[3]); + + // loop for range checks index 0, 2 + for (int index = 0; index < 2; ++index) { + arith_range_table_assumes(ARITH_RANGE_16_BITS, a[2 * index]); + arith_range_table_assumes(ARITH_RANGE_16_BITS, b[2 * index]); + arith_range_table_assumes(ARITH_RANGE_16_BITS, c[2 * index]); + arith_range_table_assumes(ARITH_RANGE_16_BITS, d[2 * index]); + } +} diff --git a/state-machines/arith/pil/arith_32.pil b/state-machines/arith/pil/arith_32.pil deleted file mode 100644 index e69de29b..00000000 diff --git a/state-machines/arith/pil/arith_3264.pil b/state-machines/arith/pil/arith_3264.pil deleted file mode 100644 index e69de29b..00000000 diff --git a/state-machines/arith/pil/arith_64.pil b/state-machines/arith/pil/arith_64.pil deleted file mode 100644 index e69de29b..00000000 diff --git 
a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil new file mode 100644 index 00000000..00837c51 --- /dev/null +++ b/state-machines/arith/pil/arith_range_table.pil @@ -0,0 +1,80 @@ +require "std_lookup.pil" +require "operations.pil" + +const int ARITH_RANGE_TABLE_ID = 330; +const int ARITH_RANGE_CARRY = 100; +const int ARITH_RANGE_16_BITS = 0; + +airtemplate ArithRangeTable(int N = 2**22) { + + // a3 a1 b3 b1 + // rid c3 c1 d3 d1 range 2^16 2^15 notes + // --- -- -- -- -- ----- ---- ---- ------------------------- + // 0 F F F F ab cd 4 0 + // 1 F F + F cd 3 1 b3 sign => a3 sign + // 2 F F - F cd 3 1 b3 sign => a3 sign + // 3 + F F F ab 3 1 c3 sign => d3 sign + // 4 + F + F ab cd 2 2 + // 5 + F - F ab cd 2 2 + // 6 - F F F ab 3 1 c3 sign => d3 sign + // 7 - F + F ab cd 2 2 + // 8 - F - F ab cd 2 2 + // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 11 F + F F ab cd 3 1 *a1 for sext/divu + // 12 F + F + ab cd 2 2 + // 13 F + F - ab cd 2 2 + // 14 F - F F ab cd 3 1 *a1 for sext/divu + // 15 F - F + ab cd 2 2 + // 16 F - F - ab cd 2 2 + // ---- ---- + + // COL COMPRESSION + // + // + // + // 0: F F F + + + - - - F F F F F F F F offset: 0 + // 1: F F F F F F F F F F F + + + - - - offset: 26 + // 2: F + - F + - F + - F F F F F F F F offset: 17 + // 3: F F F F F F F F F + - F + - F + - offset: 9 + // -------------------------------------------------------------------------------------- + // F F F + + + - - - F F F F F F F F F + - F + - F + - F F F F F F F F F F F + + + - - - + // + + // 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 + // + // F F F + + + - - - F F F F F F F F F + - F + - F + - F F F F F F F F F F F + + + - - - + // + // 25:FULL + 9:POS + 9:NEG = 34 * 2^16 = 2^21 + 2^17 + // + // a3 c3 [range, 0] => [range] + // a1 c1 [range, 1] => 
[range + 26] + // b3 d3 [range, 2] => [range + 17] + // b1 d1 [range, 3] => [range + 9] + // + // [-(2^19+2^18+2^16-1)...(2^19+2^18+2^16)] range check carry + + const int FULL = 2**16; + const int POS = 2**15; + const int NEG = 2**15; + + col fixed RANGE_ID = [0:FULL..2:FULL, 9:FULL..17:FULL, 20:FULL, 23:FULL, 26:FULL..36:FULL, // 25 FULL + 3:POS..5:POS, 18:POS, 21:POS, 24:POS, 37:POS..39:POS, // 9 POS + 6:NEG..8:NEG, 19:NEG, 22:NEG, 25:NEG, 40:NEG..42:NEG, // 9 NEG + ARITH_RANGE_CARRY...]; + + col fixed RANGE_VALUES = [[0x0000..0xFFFF]:25, + [0x0000..0x7FFF]:9, + [0x8000..0xFFFF]:9, + [-0xEFFFF..0xF0000]]; + + + col witness multiplicity; + + lookup_proves(ARITH_RANGE_TABLE_ID, [RANGE_ID, RANGE_VALUES], multiplicity); +} + +function arith_range_table_assumes(const expr range_type, const expr value, const expr sel = 1) { + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_type, value], sel:sel); +} diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil new file mode 100644 index 00000000..6788f7de --- /dev/null +++ b/state-machines/arith/pil/arith_table.pil @@ -0,0 +1,283 @@ +require "std_lookup.pil" + +const int ARITH_TABLE_ID = 331; + +airtemplate ArithTable(int N = 2**7, int generate_table = 1) { + + // div m32 sa sb primary secondary opcodes na nb np nr sext(c) + // ----------------------------------------------------------------------------------- + // 0 0 0 0 mulu muluh 0xb0 176 0xb1 177 =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh 0xb2 - 0xb3 179 a3 =0 d3 =0 =0 =0 a3, d3 + // 0 0 1 1 mul mulh 0xb4 180 0xb5 181 a3 b3 d3 =0 =0 =0 a3,b3, d3 + // 0 1 0 0 mul_w *n/a* 0xb6 182 0xb7 - =0 =0 =0 =0 c1 =0 + + // div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2) + // ------------------------------------------------------------------------------------ + // 1 0 0 0 divu remu 0xb8 184 0xb9 185 =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem 0xba 186 0xbb 187 a3 b3 c3 d3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w 0xbc 188 0xbd 189 
=0 =0 =0 =0 a1 d1 a1 ,d1 + // 1 1 1 1 div_w rem_w 0xbe 190 0xbf 191 a1 b1 c1 d1 a1 d1 a1,b1,c1,d1 + + const int OPS[14] = [0xb0, 0xb1, 0xb3, 0xb4, 0xb5, 0xb6, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf]; + + col fixed OP; + col fixed FLAGS; + col fixed RANGE_AB; + col fixed RANGE_CD; + string code = ""; + + int index = 0; + int aborted = 0; + + if (generate_table) { + int air.op2row[2048]; + for (int i = 0; i < 2048; ++i) { + op2row[i] = 255; + } + } + + for (int opcode = 0xb0; opcode <= 0xbf; ++opcode) { + if (opcode == 0xb2 || opcode == 0xb7) { + continue; + } + int m32 = 0; // 32 bits operation + int div = 0; // division operation (div,rem) + int sa = 0; + int sb = 0; + int main_mul = 0; + int main_div = 0; + string opname = ""; + switch (opcode) { + case 0xb0: + opname = "mulu"; + main_mul = 1; + case 0xb1: + opname = "mulh"; + case 0xb3: + opname = "mulsuh"; + sa = 1; + case 0xb4: + opname = "mul"; + sa = 1; + sb = 1; + main_mul = 1; + case 0xb5: + opname = "mulh"; + sa = 1; + sb = 1; + case 0xb6: + opname = "mul_w"; + m32 = 1; + main_mul = 1; + case 0xb8: + opname = "divu"; + div = 1; + main_div = 1; + case 0xb9: + opname = "remu"; + div = 1; + case 0xba: + opname = "div"; + sa = 1; + sb = 1; + div = 1; + main_div = 1; + case 0xbb: + opname = "rem"; + sa = 1; + sb = 1; + div = 1; + case 0xbc: + opname = "divu_w"; + div = 1; + m32 = 1; + main_div = 1; + case 0xbd: + opname = "remu_w"; + div = 1; + m32 = 1; + case 0xbe: + opname = "div_w"; + sa = 1; + sb = 1; + div = 1; + m32 = 1; + main_div = 1; + case 0xbf: + opname = "rem_w"; + sa = 1; + sb = 1; + div = 1; + m32 = 1; + } + + for (int icase = 0; icase < 128; ++icase) { + const int na = (0x01 & icase) ? 1 : 0; + const int nb = (0x02 & icase) ? 1 : 0; + const int np = (0x04 & icase) ? 1 : 0; + const int nr = (0x08 & icase) ? 1 : 0; + const int sext = (0x10 & icase) ? 1 : 0; + const int div_by_zero = (0x20 & icase) ? 1 : 0; + const int div_overflow = (0x40 & icase) ? 
1 : 0; + + const int signed = (sa || sb) ? 1 : 0; + + // division by zero (dividend: x, divisor: 0) + // + // DIV,DIVU 0xFFFF_FFFF_FFFF_FFFF + // REM,REMU x + // DIV_W,DIVU_W 0xFFFF_FFFF_FFFF_FFFF + // REM_W,REMU_W x + + // division overflow 64 (divend: 0x8000_0000_0000_0000, divisor: 0xFFFF_FFFF_FFFF_FFFF) + // + // DIV 0x8000_0000_0000_0000 + // REM 0 + + // division overflow 32 (divend: 0x8000_0000, divisor: 0xFFFF_FFFF) + // + // DIV_W 0xFFFF_FFFF_8000_0000 + // REM_W 0 + + // div_by_zero + // signed:1 => na:1 nb:0 np = nr (0,1) + // signed:0 => na:0 nb:0 np:0 nr:0 + + // div_overflow + // signed:1 => na:1 nb:1 np:1 nr:0 sext:0 + + if (div_by_zero && (!div || nb || np != nr || signed != na)) continue; + if (div_by_zero && main_div && m32 && !sext) continue; + if (div_overflow && (!div || !signed || !na || !nb || !np || nr)) continue; + if (sext && !m32) continue; + if (nr && !div) continue; + if (na && !sa) continue; + if (nb && !sb) continue; + if (np && !sa && !sb) continue; + if (nr && !sa && !sb) continue; + if (np && na == nb && !div) continue; + if (np && !na && !nb && !nr && div) continue; + if (na && !nb && !nr && !np && div && !div_by_zero) continue; + if (np && na && nb && !div_overflow) continue; + if (!np & nr) continue; + if (m32 && signed && main_div && na != sext) continue; + if (m32 && signed && div && !main_div && nr != sext) continue; + + int range_a1 = 0; + int range_b1 = 0; + int range_c1 = 0; + int range_d1 = 0; + int range_a3 = 0; + int range_b3 = 0; + int range_c3 = 0; + int range_d3 = 0; + + if (m32) { + if (sa) { + range_a1 = na ? 2 : 1; + } else if (main_div) { + range_a1 = sext ? 2 : 1; + } + if (sb) { + range_b1 = nb ? 2 : 1; + } + if (!div) { + range_c1 = sext ? 2 : 1; + } else if (sa) { + range_c1 = np ? 2 : 1; + } + if (div && !main_div) { + range_d1 = sext ? 2 : 1; + } else if (sa) { + range_d1 = nr ? 2 : 1; + } + } else { + if (sa) { + range_a3 = na ? 2 : 1; + if (div) { + range_c3 = np ? 2 : 1; + range_d3 = nr ? 
2 : 1; + } else { + range_d3 = np ? 2 : 1; + } + } + if (sb) { + range_b3 = nb ? 2 : 1; + } + } + const int flags = m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * sext + + 128 * div_by_zero + 256 * div_overflow + 512 * main_mul + + 1024 * main_div + 2048 * signed; + + int range_ab = (range_a3 + range_a1) * 3 + range_b3 + range_b1; + if ((range_a1 + range_b1) > 0) { + range_ab = range_ab + 8; + } + int range_cd = (range_c3 + range_c1) * 3 + range_d3 + range_d1; + if ((range_c1 + range_d1) > 0) { + range_cd = range_cd + 8; + } + // const int range_cd = range_c3 * 3 + range_d3 + m32 * 8 + range_c1 * 3 + range_d1; + + OP[index] = opcode; + FLAGS[index] = flags; + RANGE_AB[index] = range_ab; + RANGE_CD[index] = range_cd; + + if (generate_table) { + println(`OP:${opcode} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} m32:${m32} div:${div}`, + `div_by_zero:${div_by_zero} div_overflow:${div_overflow} sa:${sa} sb:${sb} main_mul:${main_mul}`, + `main_div:${main_div} signed:${signed} range_ab:${range_ab} range_cd:${range_cd} index:${(opcode - 0xb0) * 128 + icase} icase:${icase}`); + + op2row[(opcode - 0xb0) * 128 + icase] = index; + code = code + `[${opcode}, ${flags}, ${range_ab}, ${range_cd}],`; + } + ++index; + } + } + const int size = index; + + println("ARITH_TABLE SIZE: ", size); + assert(size < 256); + + if (generate_table) { + println(`pub const ROWS: usize = ${size};`); + println("const __: u8 = 255;"); + string _op2row = ""; + for (int i = 0; i < 2048; ++i) { + _op2row = _op2row + ((op2row[i] == 255) ? 
"__":string(op2row[i])) + ","; + } + println("pub static ARITH_TABLE_ROWS: [u8; 2048] = [", _op2row, "];"); + println(`pub static ARITH_TABLE: [[u16; 4]; ROWS] = [${code}];`); + } + + // padding repeat first row + + const int padding_op = OP[0]; + const int padding_flags = FLAGS[0]; + const int padding_range_ab = RANGE_AB[0]; + const int padding_range_cd = RANGE_CD[0]; + + for (index = size; index < N; ++index) { + OP[index] = padding_op; + FLAGS[index] = padding_flags; + RANGE_AB[index] = padding_range_ab; + RANGE_CD[index] = padding_range_cd; + } + col witness multiplicity; + + lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS, RANGE_AB, RANGE_CD]); +} + +function arith_table_assumes( const expr op, const expr flag_m32, const expr flag_div, const expr flag_na, + const expr flag_nb, const expr flag_np, const expr flag_nr, const expr flag_sext, + const expr flag_div_by_zero, const expr flag_div_overflow, + const expr flag_main_mul, const expr flag_main_div, const expr flag_signed, + const expr range_ab, const expr range_cd) { + + lookup_assumes(ARITH_TABLE_ID, cols: [ op, flag_m32 + 2 * flag_div + 4 * flag_na + 8 * flag_nb + + 16 * flag_np + 32 * flag_nr + 64 * flag_sext + + 128 * flag_div_by_zero + 256 * flag_div_overflow + + 512 * flag_main_mul + 1024 * flag_main_div + 2048 * flag_signed, + range_ab, range_cd]); +} diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index 3284da2d..2e2a4d5e 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -1,157 +1,68 @@ use std::sync::{ atomic::{AtomicU32, Ordering}, - Arc, Mutex, + Arc, }; -use crate::{Arith3264SM, Arith32SM, Arith64SM}; use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; -use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; +use zisk_core::ZiskRequiredOperation; +use 
zisk_pil::{ARITH_AIR_IDS, ARITH_RANGE_TABLE_AIR_IDS, ARITH_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; -const PROVE_CHUNK_SIZE: usize = 1 << 12; +use crate::{ArithFullSM, ArithRangeTableSM, ArithTableSM}; #[allow(dead_code)] -pub struct ArithSM { +pub struct ArithSM { // Count of registered predecessors registered_predecessors: AtomicU32, - // Inputs - inputs32: Mutex>, - inputs64: Mutex>, - - // Secondary State machines - arith32_sm: Arc, - arith64_sm: Arc, - arith3264_sm: Arc, + arith_full_sm: Arc>, + arith_table_sm: Arc>, + arith_range_table_sm: Arc>, } -impl ArithSM { - pub fn new(wcm: Arc>) -> Arc { - let arith32_sm = Arith32SM::new(wcm.clone()); - let arith64_sm = Arith64SM::new(wcm.clone()); - let arith3264_sm = Arith3264SM::new(wcm.clone()); - +impl ArithSM { + pub fn new(wcm: Arc>) -> Arc { + let arith_table_sm = ArithTableSM::new(wcm.clone(), ZISK_AIRGROUP_ID, ARITH_TABLE_AIR_IDS); + let arith_range_table_sm = + ArithRangeTableSM::new(wcm.clone(), ZISK_AIRGROUP_ID, ARITH_RANGE_TABLE_AIR_IDS); + let arith_full_sm = ArithFullSM::new( + wcm.clone(), + arith_table_sm.clone(), + arith_range_table_sm.clone(), + ZISK_AIRGROUP_ID, + ARITH_AIR_IDS, + ); let arith_sm = Self { registered_predecessors: AtomicU32::new(0), - inputs32: Mutex::new(Vec::new()), - inputs64: Mutex::new(Vec::new()), - arith32_sm, - arith64_sm, - arith3264_sm, + arith_full_sm, + arith_table_sm, + arith_range_table_sm, }; let arith_sm = Arc::new(arith_sm); wcm.register_component(arith_sm.clone(), None, None); - arith_sm.arith32_sm.register_predecessor(); - arith_sm.arith64_sm.register_predecessor(); - arith_sm.arith3264_sm.register_predecessor(); + arith_sm.arith_full_sm.register_predecessor(); arith_sm } - pub fn register_predecessor(&self) { self.registered_predecessors.fetch_add(1, Ordering::SeqCst); } - pub fn unregister_predecessor(&self, scope: &Scope) { + pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, 
scope); - - self.arith3264_sm.unregister_predecessor::(scope); - self.arith64_sm.unregister_predecessor::(scope); - self.arith32_sm.unregister_predecessor::(scope); + self.arith_full_sm.unregister_predecessor(); } } -} - -impl WitnessComponent for ArithSM { - fn calculate_witness( + pub fn prove_instance( &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + operations: Vec, + prover_buffer: &mut [F], + offset: u64, ) { + self.arith_full_sm.prove_instance(operations, prover_buffer, offset); } } -impl Provable for ArithSM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> Result> { - let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); - Ok(result) - } - - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - let mut _inputs32 = Vec::new(); - let mut _inputs64 = Vec::new(); - - let operations32 = Arith32SM::operations(); - let operations64 = Arith64SM::operations(); - - // TODO Split the operations into 32 and 64 bit operations in parallel - for operation in operations { - if operations32.contains(&operation.opcode) { - _inputs32.push(operation.clone()); - } else if operations64.contains(&operation.opcode) { - _inputs64.push(operation.clone()); - } else { - panic!("ArithSM: Operator {:x} not found", operation.opcode); - } - } - - // TODO When drain is true, drain remaining inputs to the 3264 bits state machine - - let mut inputs32 = self.inputs32.lock().unwrap(); - inputs32.extend(_inputs32); - - while inputs32.len() >= PROVE_CHUNK_SIZE || (drain && !inputs32.is_empty()) { - if drain && !inputs32.is_empty() { - // println!("ArithSM: Draining inputs32"); - } - - let num_drained32 = std::cmp::min(PROVE_CHUNK_SIZE, inputs32.len()); - let drained_inputs32 = inputs32.drain(..num_drained32).collect::>(); - let arith32_sm_cloned = self.arith32_sm.clone(); - - arith32_sm_cloned.prove(&drained_inputs32, drain, scope); - } - drop(inputs32); - - let mut 
inputs64 = self.inputs64.lock().unwrap(); - inputs64.extend(_inputs64); - - while inputs64.len() >= PROVE_CHUNK_SIZE || (drain && !inputs64.is_empty()) { - if drain && !inputs64.is_empty() { - // println!("ArithSM: Draining inputs64"); - } - - let num_drained64 = std::cmp::min(PROVE_CHUNK_SIZE, inputs64.len()); - let drained_inputs64 = inputs64.drain(..num_drained64).collect::>(); - let arith64_sm_cloned = self.arith64_sm.clone(); - - arith64_sm_cloned.prove(&drained_inputs64, drain, scope); - } - drop(inputs64); - } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} +impl WitnessComponent for ArithSM {} diff --git a/state-machines/arith/src/arith_32.rs b/state-machines/arith/src/arith_32.rs deleted file mode 100644 index d4241b56..00000000 --- a/state-machines/arith/src/arith_32.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::{ - fmt::Error, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, - }, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; -use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::{ARITH32_AIR_IDS, ARITH_AIRGROUP_ID}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct Arith32SM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -impl Arith32SM { - pub fn new(wcm: Arc>) -> Arc { - let arith32_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let arith32_sm = Arc::new(arith32_sm); - - wcm.register_component(arith32_sm.clone(), Some(ARITH_AIRGROUP_ID), Some(ARITH32_AIR_IDS)); - - arith32_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); 
- } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - pub fn operations() -> Vec { - vec![0xb6, 0xb7, 0xbe, 0xbf] - } -} - -impl WitnessComponent for Arith32SM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, - ) { - } -} - -impl Provable for Arith32SM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> Result> { - let result: OpResult = ZiskOp::execute( - ZiskOp::try_from_code(operation.opcode).map_err(|_| Error)?.code(), - operation.a, - operation.b, - ); - Ok(result) - } - - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! 
Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/arith/src/arith_3264.rs b/state-machines/arith/src/arith_3264.rs deleted file mode 100644 index 81e289bb..00000000 --- a/state-machines/arith/src/arith_3264.rs +++ /dev/null @@ -1,107 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; -use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; - -use p3_field::Field; -use zisk_pil::{ARITH3264_AIR_IDS, ARITH_AIRGROUP_ID}; -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct Arith3264SM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -impl Arith3264SM { - pub fn new(wcm: Arc>) -> Arc { - let arith3264_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let arith3264_sm = Arc::new(arith3264_sm); - - wcm.register_component( - arith3264_sm.clone(), - Some(ARITH_AIRGROUP_ID), - Some(ARITH3264_AIR_IDS), - ); - - arith3264_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove( - self, - &[], - true, - scope, - ); - } - } -} - -impl WitnessComponent for Arith3264SM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, - ) { - } -} - -impl Provable for Arith3264SM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> 
Result> { - let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); - Ok(result) - } - - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - if drain && !inputs.is_empty() { - // println!("Arith3264SM: Draining inputs3264"); - } - - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/arith/src/arith_64.rs b/state-machines/arith/src/arith_64.rs deleted file mode 100644 index 7660ad12..00000000 --- a/state-machines/arith/src/arith_64.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; -use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::{ARITH64_AIR_IDS, ARITH_AIRGROUP_ID}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct Arith64SM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -impl Arith64SM { - pub fn new(wcm: Arc>) -> Arc { - let arith64_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let arith64_sm = Arc::new(arith64_sm); - - wcm.register_component(arith64_sm.clone(), Some(ARITH_AIRGROUP_ID), Some(ARITH64_AIR_IDS)); - - 
arith64_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - pub fn operations() -> Vec { - vec![0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb8, 0xb9, 0xba, 0xbb] - } -} - -impl WitnessComponent for Arith64SM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, - ) { - } -} - -impl Provable for Arith64SM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> Result> { - let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); - Ok(result) - } - - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! 
Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/arith/src/arith_constants.rs b/state-machines/arith/src/arith_constants.rs new file mode 100644 index 00000000..4a7af91a --- /dev/null +++ b/state-machines/arith/src/arith_constants.rs @@ -0,0 +1,14 @@ +pub const MULU: u8 = 0xb0; +pub const MULUH: u8 = 0xb1; +pub const MULSUH: u8 = 0xb3; +pub const MUL: u8 = 0xb4; +pub const MULH: u8 = 0xb5; +pub const MUL_W: u8 = 0xb6; +pub const DIVU: u8 = 0xb8; +pub const REMU: u8 = 0xb9; +pub const DIV: u8 = 0xba; +pub const REM: u8 = 0xbb; +pub const DIVU_W: u8 = 0xbc; +pub const REMU_W: u8 = 0xbd; +pub const DIV_W: u8 = 0xbe; +pub const REM_W: u8 = 0xbf; diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs new file mode 100644 index 00000000..96590ba3 --- /dev/null +++ b/state-machines/arith/src/arith_full.rs @@ -0,0 +1,234 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, +}; + +use crate::{ + arith_constants::*, ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithTableInputs, + ArithTableSM, +}; +use log::info; +use p3_field::Field; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; +use sm_common::i64_to_u64_field; +use zisk_core::ZiskRequiredOperation; +use zisk_pil::*; + +pub struct ArithFullSM { + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Inputs + arith_table_sm: Arc>, + arith_range_table_sm: Arc>, +} + +impl ArithFullSM { + const MY_NAME: &'static str = "Arith "; + pub fn new( + wcm: Arc>, + arith_table_sm: Arc>, + arith_range_table_sm: Arc>, + airgroup_id: usize, + air_ids: &[usize], + ) -> Arc { + let 
arith_full_sm = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + arith_table_sm, + arith_range_table_sm, + }; + let arith_full_sm = Arc::new(arith_full_sm); + + wcm.register_component(arith_full_sm.clone(), Some(airgroup_id), Some(air_ids)); + + arith_full_sm.arith_table_sm.register_predecessor(); + arith_full_sm.arith_range_table_sm.register_predecessor(); + + arith_full_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.arith_table_sm.unregister_predecessor(); + self.arith_range_table_sm.unregister_predecessor(); + } + } + pub fn prove_instance( + &self, + input: Vec, + prover_buffer: &mut [F], + offset: u64, + ) { + let mut range_table_inputs = ArithRangeTableInputs::new(); + let mut table_inputs = ArithTableInputs::new(); + + let pctx = self.wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, ARITH_AIR_IDS[0]); + let num_rows = air.num_rows(); + timer_start_trace!(ARITH_TRACE); + info!( + "{}: ··· Creating Arith instance KKKKK [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + input.len(), + num_rows, + input.len() as f64 / num_rows as f64 * 100.0 + ); + assert!(input.len() <= num_rows); + + let mut traces = + ArithTrace::::map_buffer(prover_buffer, num_rows, offset as usize).unwrap(); + + let mut aop = ArithOperation::new(); + for (irow, input) in input.iter().enumerate() { + println!("#{} ARITH op:0x{:X} a:0x{:X} b:0x{:X}", irow, input.opcode, input.a, input.b); + aop.calculate(input.opcode, input.a, input.b); + let mut t: ArithRow = Default::default(); + for i in [0, 2] { + t.a[i] = F::from_canonical_u64(aop.a[i]); + t.b[i] = F::from_canonical_u64(aop.b[i]); + t.c[i] = F::from_canonical_u64(aop.c[i]); + t.d[i] = F::from_canonical_u64(aop.d[i]); + range_table_inputs.use_chunk_range_check(0, aop.a[i]); + 
range_table_inputs.use_chunk_range_check(0, aop.b[i]); + range_table_inputs.use_chunk_range_check(0, aop.c[i]); + range_table_inputs.use_chunk_range_check(0, aop.d[i]); + } + for i in [1, 3] { + t.a[i] = F::from_canonical_u64(aop.a[i]); + t.b[i] = F::from_canonical_u64(aop.b[i]); + t.c[i] = F::from_canonical_u64(aop.c[i]); + t.d[i] = F::from_canonical_u64(aop.d[i]); + } + range_table_inputs.use_chunk_range_check(aop.range_ab, aop.a[3]); + range_table_inputs.use_chunk_range_check(aop.range_ab + 26, aop.a[1]); + range_table_inputs.use_chunk_range_check(aop.range_ab + 17, aop.b[3]); + range_table_inputs.use_chunk_range_check(aop.range_ab + 9, aop.b[1]); + + range_table_inputs.use_chunk_range_check(aop.range_cd, aop.c[3]); + range_table_inputs.use_chunk_range_check(aop.range_cd + 26, aop.c[1]); + range_table_inputs.use_chunk_range_check(aop.range_cd + 17, aop.d[3]); + range_table_inputs.use_chunk_range_check(aop.range_cd + 9, aop.d[1]); + + for i in 0..7 { + t.carry[i] = F::from_canonical_u64(i64_to_u64_field(aop.carry[i])); + range_table_inputs.use_carry_range_check(aop.carry[i]); + } + t.op = F::from_canonical_u8(aop.op); + t.m32 = F::from_bool(aop.m32); + t.div = F::from_bool(aop.div); + t.na = F::from_bool(aop.na); + t.nb = F::from_bool(aop.nb); + t.np = F::from_bool(aop.np); + t.nr = F::from_bool(aop.nr); + t.signed = F::from_bool(aop.signed); + t.main_mul = F::from_bool(aop.main_mul); + t.main_div = F::from_bool(aop.main_div); + t.sext = F::from_bool(aop.sext); + t.multiplicity = F::one(); + t.debug_main_step = F::from_canonical_u64(input.step); + t.range_ab = F::from_canonical_u8(aop.range_ab); + t.range_cd = F::from_canonical_u8(aop.range_cd); + t.div_by_zero = F::from_bool(aop.div_by_zero); + t.div_overflow = F::from_bool(aop.div_overflow); + t.inv_sum_all_bs = if aop.div && !aop.div_by_zero { + F::from_canonical_u64(aop.b[0] + aop.b[1] + aop.b[2] + aop.b[3]).inverse() + } else { + F::zero() + }; + + table_inputs.add_use( + aop.op, + aop.na, + aop.nb, + 
aop.np, + aop.nr, + aop.sext, + aop.div_by_zero, + aop.div_overflow, + ); + + t.fab = if aop.na != aop.nb { F::neg_one() } else { F::one() }; + // na * (1 - 2 * nb); + t.na_fb = if aop.na { + if aop.nb { + F::neg_one() + } else { + F::one() + } + } else { + F::zero() + }; + t.nb_fa = if aop.nb { + if aop.na { + F::neg_one() + } else { + F::one() + } + } else { + F::zero() + }; + t.bus_res1 = F::from_canonical_u64(if aop.sext { + 0xFFFFFFFF + } else if aop.m32 { + 0 + } else if aop.main_mul { + aop.c[2] + (aop.c[3] << 16) + } else if aop.main_div { + aop.a[2] + (aop.a[3] << 16) + } else { + aop.d[2] + (aop.d[3] << 16) + }); + traces[irow] = t; + } + timer_stop_and_log_trace!(ARITH_TRACE); + + timer_start_trace!(ARITH_PADDING); + let padding_offset = input.len(); + let padding_rows: usize = + if num_rows > padding_offset { num_rows - padding_offset } else { 0 }; + + if padding_rows > 0 { + let mut t: ArithRow = Default::default(); + let padding_opcode = MULUH; + t.op = F::from_canonical_u8(padding_opcode); + t.fab = F::one(); + for i in padding_offset..num_rows { + traces[i] = t; + } + range_table_inputs.multi_use_chunk_range_check(padding_rows * 10, 0, 0); + range_table_inputs.multi_use_chunk_range_check(padding_rows * 2, 26, 0); + range_table_inputs.multi_use_chunk_range_check(padding_rows * 2, 17, 0); + range_table_inputs.multi_use_chunk_range_check(padding_rows * 2, 9, 0); + range_table_inputs.multi_use_carry_range_check(padding_rows * 7, 0); + table_inputs.multi_add_use( + padding_rows, + padding_opcode, + false, + false, + false, + false, + false, + false, + false, + ); + } + timer_stop_and_log_trace!(ARITH_PADDING); + timer_start_trace!(ARITH_TABLE); + info!("{}: ··· calling arit_table_sm", Self::MY_NAME); + self.arith_table_sm.process_slice(&table_inputs); + timer_stop_and_log_trace!(ARITH_TABLE); + timer_start_trace!(ARITH_RANGE_TABLE); + self.arith_range_table_sm.process_slice(&range_table_inputs); + timer_stop_and_log_trace!(ARITH_RANGE_TABLE); + } +} + 
+impl WitnessComponent for ArithFullSM {} diff --git a/state-machines/arith/src/arith_operation.rs b/state-machines/arith/src/arith_operation.rs new file mode 100644 index 00000000..a4eecdf2 --- /dev/null +++ b/state-machines/arith/src/arith_operation.rs @@ -0,0 +1,683 @@ +use crate::{arith_constants::*, arith_range_table_helpers::*}; +use std::fmt; + +pub struct ArithOperation { + pub op: u8, + pub input_a: u64, + pub input_b: u64, + pub a: [u64; 4], + pub b: [u64; 4], + pub c: [u64; 4], + pub d: [u64; 4], + pub carry: [i64; 7], + pub m32: bool, + pub div: bool, + pub na: bool, + pub nb: bool, + pub np: bool, + pub nr: bool, + pub sext: bool, + pub main_mul: bool, + pub main_div: bool, + pub signed: bool, + pub range_ab: u8, + pub range_cd: u8, + pub div_by_zero: bool, + pub div_overflow: bool, +} + +impl Default for ArithOperation { + fn default() -> Self { + Self::new() + } +} +impl fmt::Debug for ArithOperation { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut flags = String::new(); + if self.m32 { + flags += "m32 " + }; + if self.div { + flags += "div " + }; + if self.na { + flags += "na " + }; + if self.nb { + flags += "nb " + }; + if self.np { + flags += "np " + }; + if self.nr { + flags += "nr " + }; + if self.sext { + flags += "sext " + }; + if self.div_by_zero { + flags += "div_by_zero " + }; + if self.div_overflow { + flags += "div_overflow " + }; + if self.main_mul { + flags += "main_mul " + }; + if self.main_div { + flags += "main_div " + }; + if self.signed { + flags += "signed " + }; + writeln!(f, "operation 0x{:x} flags={}", self.op, flags)?; + writeln!(f, "input_a: 0x{0:x}({0})", self.input_a)?; + writeln!(f, "input_b: 0x{0:x}({0})", self.input_b)?; + self.dump_chunks(f, "a", &self.a)?; + self.dump_chunks(f, "b", &self.b)?; + self.dump_chunks(f, "c", &self.c)?; + self.dump_chunks(f, "d", &self.d)?; + writeln!( + f, + "carry: [0x{0:X}({0}), 0x{1:X}({1}), 0x{2:X}({2}), 0x{3:X}({3}), 0x{4:X}({4}), 0x{5:X}({5}), 0x{6:X}({6})]", + 
self.carry[0], self.carry[1], self.carry[2], self.carry[3], self.carry[4], self.carry[5], self.carry[6] + )?; + writeln!( + f, + "range_ab: 0x{0:X} {1}, range_cd:0x{2:X} {3}", + self.range_ab, + ArithRangeTableHelpers::get_range_name(self.range_ab), + self.range_cd, + ArithRangeTableHelpers::get_range_name(self.range_cd) + ) + } +} + +impl ArithOperation { + fn dump_chunks(&self, f: &mut fmt::Formatter, name: &str, value: &[u64; 4]) -> fmt::Result { + writeln!( + f, + "{0}: [0x{1:X}({1}), 0x{2:X}({2}), 0x{3:X}({3}), 0x{4:X}({4})]", + name, value[0], value[1], value[2], value[3] + ) + } + pub fn new() -> Self { + Self { + op: 0, + input_a: 0, + input_b: 0, + a: [0, 0, 0, 0], + b: [0, 0, 0, 0], + c: [0, 0, 0, 0], + d: [0, 0, 0, 0], + carry: [0, 0, 0, 0, 0, 0, 0], + m32: false, + div: false, + na: false, + nb: false, + np: false, + nr: false, + sext: false, + div_by_zero: false, + div_overflow: false, + main_mul: false, + main_div: false, + signed: false, + range_ab: 0, + range_cd: 0, + } + } + pub fn calculate(&mut self, op: u8, input_a: u64, input_b: u64) { + self.op = op; + self.input_a = input_a; + self.input_b = input_b; + self.div_by_zero = input_b == 0 && + (op == DIV || + op == REM || + op == DIV_W || + op == REM_W || + op == DIVU || + op == REMU || + op == DIVU_W || + op == REMU_W); + + self.div_overflow = ((op == DIV || op == REM) && + input_a == 0x8000_0000_0000_0000 && + input_b == 0xFFFF_FFFF_FFFF_FFFF) || + ((op == DIV_W || op == REM_W) && input_a == 0x8000_0000 && input_b == 0xFFFF_FFFF); + + let [a, b, c, d] = Self::calculate_abcd_from_ab(op, input_a, input_b); + self.a = Self::u64_to_chunks(a); + self.b = Self::u64_to_chunks(b); + self.c = Self::u64_to_chunks(c); + self.d = Self::u64_to_chunks(d); + self.update_flags_and_ranges(op, a, b, c, d); + let chunks = self.calculate_chunks(); + self.update_carries(&chunks); + } + fn update_carries(&mut self, chunks: &[i64; 8]) { + for (i, chunk) in chunks.iter().enumerate() { + let chunk_value = chunk + if i > 
0 { self.carry[i - 1] } else { 0 }; + if i >= 7 { + continue; + } + self.carry[i] = chunk_value / 0x10000; + } + } + fn sign32(abs_value: u64, negative: bool) -> u64 { + assert!(0xFFFF_FFFF >= abs_value, "abs_value:0x{0:X}({0}) is too big", abs_value); + if negative { + (0xFFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + + fn sign64(abs_value: u64, negative: bool) -> u64 { + if negative { + (0xFFFF_FFFF_FFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + fn sign128(abs_value: u128, negative: bool) -> u128 { + if negative { + (0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + fn abs32(value: u64) -> [u64; 2] { + let negative = if (value & 0x8000_0000) != 0 { 1 } else { 0 }; + let abs_value = if negative == 1 { (0xFFFF_FFFF - value) + 1 } else { value }; + [abs_value, negative] + } + fn abs64(value: u64) -> [u64; 2] { + let negative = if (value & 0x8000_0000_0000_0000) != 0 { 1 } else { 0 }; + let abs_value = if negative == 1 { (0xFFFF_FFFF_FFFF_FFFF - value) + 1 } else { value }; + [abs_value, negative] + } + fn calculate_mul_w(a: u64, b: u64) -> u64 { + (a & 0xFFFF_FFFF) * (b & 0xFFFF_FFFF) + } + + fn calculate_mulsu(a: u64, b: u64) -> [u64; 2] { + let [abs_a, na] = Self::abs64(a); + let abs_c = abs_a as u128 * b as u128; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + let c = Self::sign128(abs_c, nc == 1); + [c as u64, (c >> 64) as u64] + } + + fn calculate_mul(a: u64, b: u64) -> [u64; 2] { + let [abs_a, na] = Self::abs64(a); + let [abs_b, nb] = Self::abs64(b); + let abs_c = abs_a as u128 * abs_b as u128; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + let c = Self::sign128(abs_c, nc == 1); + [c as u64, (c >> 64) as u64] + } + + fn calculate_div(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs64(a); + let [abs_b, nb] = Self::abs64(b); + if abs_b == 0 { + 0xFFFF_FFFF_FFFF_FFFF + } else { + let abs_c = abs_a / abs_b; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + 
Self::sign64(abs_c, nc == 1) + } + } + fn calculate_div_w(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs32(a); + let [abs_b, nb] = Self::abs32(b); + if abs_b == 0 { + 0xFFFF_FFFF + } else { + let abs_c = abs_a / abs_b; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + Self::sign32(abs_c, nc == 1) + } + } + + fn calculate_divu(a: u64, b: u64) -> u64 { + if b == 0 { + 0xFFFF_FFFF_FFFF_FFFF + } else { + a / b + } + } + + fn calculate_divu_w(a: u64, b: u64) -> u64 { + if b == 0 { + 0xFFFF_FFFF + } else { + a / b + } + } + + fn calculate_remu(a: u64, b: u64) -> u64 { + if b == 0 { + a + } else { + a % b + } + } + + fn calculate_remu_w(a: u64, b: u64) -> u64 { + if b == 0 { + a + } else { + a % b + } + } + + fn calculate_rem(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs64(a); + let [abs_b, _nb] = Self::abs64(b); + if abs_b == 0 { + a + } else { + let abs_c = abs_a % abs_b; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + Self::sign64(abs_c, nc == 1) + } + } + + fn calculate_rem_w(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs32(a); + let [abs_b, _nb] = Self::abs32(b); + if abs_b == 0 { + a + } else { + let abs_c = abs_a % abs_b; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + Self::sign32(abs_c, nc == 1) + } + } + + fn calculate_abcd_from_ab(op: u8, a: u64, b: u64) -> [u64; 4] { + match op { + MULU | MULUH => { + let c: u128 = a as u128 * b as u128; + [a, b, c as u64, (c >> 64) as u64] + } + MULSUH => { + let [c, d] = Self::calculate_mulsu(a, b); + [a, b, c, d] + } + MUL | MULH => { + let [c, d] = Self::calculate_mul(a, b); + [a, b, c, d] + } + MUL_W => [a, b, Self::calculate_mul_w(a, b), 0], + DIVU | REMU => [Self::calculate_divu(a, b), b, a, Self::calculate_remu(a, b)], + DIVU_W | REMU_W => [Self::calculate_divu_w(a, b), b, a, Self::calculate_remu_w(a, b)], + DIV | REM => [Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)], + DIV_W | REM_W => [Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)], + _ => 
{ + panic!("Invalid opcode"); + } + } + } + fn update_flags_and_ranges(&mut self, op: u8, a: u64, b: u64, c: u64, d: u64) { + self.m32 = false; + self.div = false; + self.np = false; + self.nr = false; + self.sext = false; + self.main_mul = false; + self.main_div = false; + self.signed = false; + + let mut range_a1: u8 = 0; + let mut range_b1: u8 = 0; + let mut range_c1: u8 = 0; + let mut range_d1: u8 = 0; + let mut range_a3: u8 = 0; + let mut range_b3: u8 = 0; + let mut range_c3: u8 = 0; + let mut range_d3: u8 = 0; + + // direct table opcode(14), signed 2 or 4 cases (0,na,nb,na+nb) + // 6 * 1 + 7 * 4 + 1 * 2 = 36 entries, + // no compacted => 16 x 4 = 64, key = (op - 0xb0) * 4 + na * 2 + nb + // output: div, m32, sa, sb, nr, np, na, na32, nd32, range x 2 x 4 + + // alternative: switch operation, + + let mut sa = false; + let mut sb = false; + let mut rem = false; + + match op { + MULU => { + self.main_mul = true; + } + MULUH => {} + MULSUH => { + sa = true; + } + MUL => { + sa = true; + sb = true; + self.main_mul = true; + } + MULH => { + sa = true; + sb = true; + } + MUL_W => { + self.m32 = true; + self.sext = ((a * b) & 0xFFFF_FFFF) & 0x8000_0000 != 0; + self.main_mul = true; + } + DIVU => { + self.div = true; + self.main_div = true; + } + REMU => { + self.div = true; + rem = true; + } + DIV => { + sa = true; + sb = true; + self.div = true; + self.main_div = true; + } + REM => { + sa = true; + sb = true; + rem = true; + self.div = true; + } + DIVU_W => { + // divu_w, remu_w + self.div = true; + self.m32 = true; + // use a in bus + self.sext = (a & 0x8000_0000) != 0; + self.main_div = true; + } + REMU_W => { + // divu_w, remu_w + self.div = true; + self.m32 = true; + rem = true; + // use d in bus + self.sext = (d & 0x8000_0000) != 0; + } + DIV_W => { + // div_w, rem_w + sa = true; + sb = true; + self.div = true; + self.m32 = true; + // use a in bus + self.sext = (a & 0x8000_0000) != 0; + self.main_div = true; + } + REM_W => { + // div_w, rem_w + sa = true; + sb = 
true; + self.div = true; + self.m32 = true; + rem = true; + // use d in bus + self.sext = (d & 0x8000_0000) != 0; + } + _ => { + panic!("Invalid opcode"); + } + } + self.signed = sa || sb; + + let sign_mask: u64 = if self.m32 { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; + let sign_c_mask: u64 = + if self.m32 && self.div { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; + self.na = sa && (a & sign_mask) != 0; + self.nb = sb && (b & sign_mask) != 0; + // a sign => b sign + let nc = sa && (c & sign_c_mask) != 0; + let nd = sa && (d & sign_mask) != 0; + + // a == 0 || b == 0 => np == 0 ==> how was a signed operation + // after that sign of np was verified with range check. + // TODO: review if secure + if self.div { + self.np = nc; //if c != 0 { na ^ nb } else { 0 }; + self.nr = nd; + } else { + self.np = if self.m32 { nc } else { nd }; // if (c != 0) || (d != 0) { na ^ nb } else { 0 } + self.nr = false; + } + if self.m32 { + // mulw, divu_w, remu_w, div_w, rem_w + range_a1 = if sa { + if self.na { + 2 + } else { + 1 + } + } else if self.div && !rem { + if self.sext { + 2 + } else { + 1 + } + } else { + 0 + }; + range_b1 = if sb { + if self.nb { + 2 + } else { + 1 + } + } else { + 0 + }; + // m32 && div == 0 => mulw + range_c1 = if !self.div { + if self.sext { + 2 + } else { + 1 + } + } else if sa { + if self.np { + 2 + } else { + 1 + } + } else { + 0 + }; + range_d1 = if rem { + if self.sext { + 2 + } else { + 1 + } + } else if sa { + if self.nr { + 2 + } else { + 1 + } + } else { + 0 + }; + } else { + // mulu, muluh, mulsuh, mul, mulh, div, rem, divu, remu + if sa { + // mulsuh, mul, mulh, div, rem + range_a3 = if self.na { 2 } else { 1 }; + if self.div { + // div, rem + range_c3 = if self.np { 2 } else { 1 }; + range_d3 = if self.nr { 2 } else { 1 } + } else { + range_d3 = if self.np { 2 } else { 1 } + } + } + // sb => mul, mulh, div, rem + range_b3 = if sb { + if self.nb { + 2 + } else { + 1 + } + } else { + 0 + }; + } + + // range_ab / range_cd + // + // a3 a1 b3 
b1 + // rid c3 c1 d3 d1 range 2^16 2^15 notes + // --- -- -- -- -- ----- ---- ---- ------------------------- + // 0 F F F F ab cd 4 0 + // 1 F F + F cd 3 1 b3 sign => a3 sign + // 2 F F - F cd 3 1 b3 sign => a3 sign + // 3 + F F F ab 3 1 c3 sign => d3 sign + // 4 + F + F ab cd 2 2 + // 5 + F - F ab cd 2 2 + // 6 - F F F ab 3 1 c3 sign => d3 sign + // 7 - F + F ab cd 2 2 + // 8 - F - F ab cd 2 2 + // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 11 F + F F cd 3 1 a1 sign <=> b1 sign + // 12 F + F + ab cd 2 2 + // 13 F + F - ab cd 2 2 + // 14 F - F F cd 3 1 a1 sign <=> b1 sign + // 15 F - F + ab cd 2 2 + // 16 F - F - ab cd 2 2 + + assert!(range_a1 == 0 || range_a3 == 0, "range_a1:{} range_a3:{}", range_a1, range_a3); + assert!(range_b1 == 0 || range_b3 == 0, "range_b1:{} range_b3:{}", range_b1, range_b3); + assert!(range_c1 == 0 || range_c3 == 0, "range_c1:{} range_c3:{}", range_c1, range_c3); + assert!(range_d1 == 0 || range_d3 == 0, "range_d1:{} range_d3:{}", range_d1, range_d3); + + self.range_ab = (range_a3 + range_a1) * 3 + + range_b3 + + range_b1 + + if (range_a1 + range_b1) > 0 { 8 } else { 0 }; + + self.range_cd = (range_c3 + range_c1) * 3 + + range_d3 + + range_d1 + + if (range_c1 + range_d1) > 0 { 8 } else { 0 }; + } + + pub fn calculate_chunks(&self) -> [i64; 8] { + // TODO: unroll this function in variants (div,m32) and (na,nb,nr,np) + // div, m32, na, nb === f(div,m32,na,nb) => fa, nb, nr + // unroll means 16 variants ==> but more performance + + let mut chunks: [i64; 8] = [0, 0, 0, 0, 0, 0, 0, 0]; + + let fab = if self.na != self.nb { -1 } else { 1 }; + + let a = [self.a[0] as i64, self.a[1] as i64, self.a[2] as i64, self.a[3] as i64]; + let b = [self.b[0] as i64, self.b[1] as i64, self.b[2] as i64, self.b[3] as i64]; + let c = [self.c[0] as i64, self.c[1] as i64, self.c[2] as i64, self.c[3] as i64]; + let d = [self.d[0] as i64, self.d[1] as i64, self.d[2] as i64, self.d[3] as 
i64]; + + let na = self.na as i64; + let nb = self.nb as i64; + let np = self.np as i64; + let nr = self.nr as i64; + let m32 = self.m32 as i64; + let div = self.div as i64; + + let na_fb = na * (1 - 2 * nb); + let nb_fa = nb * (1 - 2 * na); + + chunks[0] = fab * a[0] * b[0] // chunk0 + - c[0] + + 2 * np * c[0] + + div * d[0] - + 2 * nr * d[0]; + + chunks[1] = fab * a[1] * b[0] // chunk1 + + fab * a[0] * b[1] - + c[1] + + 2 * np * c[1] + + div * d[1] - + 2 * nr * d[1]; + + chunks[2] = fab * a[2] * b[0] // chunk2 + + fab * a[1] * b[1] + + fab * a[0] * b[2] + + a[0] * nb_fa * m32 + + b[0] * na_fb * m32 - + c[2] + + 2 * np * c[2] + + div * d[2] - + 2 * nr * d[2] - + np * div * m32 + + nr * m32; // div == 0 ==> nr = 0 + + chunks[3] = fab * a[3] * b[0] // chunk3 + + fab * a[2] * b[1] + + fab * a[1] * b[2] + + fab * a[0] * b[3] + + a[1] * nb_fa * m32 + + b[1] * na_fb * m32 - + c[3] + + 2 * np * c[3] + + div * d[3] - + 2 * nr * d[3]; + + chunks[4] = fab * a[3] * b[1] // chunk4 + + fab * a[2] * b[2] + + fab * a[1] * b[3] + + na * nb * m32 + // + b[0] * na * (1 - 2 * nb) + // + a[0] * nb * (1 - 2 * na) + + b[0] * na_fb * (1 - m32) + + a[0] * nb_fa * (1 - m32) + // high bits ^^^ + // - np * div + // + np * div * m32 + // - 2 * div * m32 * np + - np * m32 * (1 - div) // + - np * (1 - m32) * div // 2^64 (np) + + nr * (1 - m32) // 2^64 (nr) + // high part d + - d[0] * (1 - div) // m32 == 1 and div == 0 => d = 0 + + 2 * np * d[0] * (1 - div); // + + chunks[5] = fab * a[3] * b[2] // chunk5 + + fab * a[2] * b[3] + + a[1] * nb_fa * (1 - m32) + + b[1] * na_fb * (1 - m32) - + d[1] * (1 - div) + + d[1] * 2 * np * (1 - div); + + chunks[6] = fab * a[3] * b[3] // chunk6 + + a[2] * nb_fa * (1 - m32) + + b[2] * na_fb * (1 - m32) - + d[2] * (1 - div) + + d[2] * 2 * np * (1 - div); + + // 0x4000_0000_0000_0000__8000_0000_0000_0000 + chunks[7] = 0x10000 * na * nb * (1 - m32) // chunk7 + + a[3] * nb_fa * (1 - m32) + + b[3] * na_fb * (1 - m32) - + 0x10000 * np * (1 - div) * (1 - m32) - + d[3] * 
(1 - div) + + d[3] * 2 * np * (1 - div); + + chunks + } + fn u64_to_chunks(a: u64) -> [u64; 4] { + [a & 0xFFFF, (a >> 16) & 0xFFFF, (a >> 32) & 0xFFFF, (a >> 48) & 0xFFFF] + } +} diff --git a/state-machines/arith/src/arith_operation_test.rs b/state-machines/arith/src/arith_operation_test.rs new file mode 100644 index 00000000..75f1aa4a --- /dev/null +++ b/state-machines/arith/src/arith_operation_test.rs @@ -0,0 +1,254 @@ +use zisk_core::zisk_ops::*; + +use crate::{ + arith_constants::*, arith_table_data, ArithOperation, ArithRangeTableHelpers, ArithTableHelpers, +}; + +const MIN_N_64: u64 = 0x8000_0000_0000_0000; +const MIN_N_32: u64 = 0x0000_0000_8000_0000; +const MAX_P_64: u64 = 0x7FFF_FFFF_FFFF_FFFF; +const MAX_P_32: u64 = 0x0000_0000_7FFF_FFFF; +const MAX_32: u64 = 0x0000_0000_FFFF_FFFF; +const MAX_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; + +const ALL_VALUES: [u64; 16] = [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + MAX_64 - 1, + MIN_N_64, + MIN_N_64 + 1, + MAX_64, +]; + +const ALL_OPERATIONS: [u8; 14] = + [MUL, MULH, MULSUH, MULU, MULUH, DIVU, REMU, DIV, REM, MUL_W, DIVU_W, REMU_W, DIV_W, REM_W]; + +struct ArithOperationTest { + count: u32, + fail: u32, + fail_by_op: [u32; 16], + pending: u32, + table_rows: [u16; arith_table_data::ROWS], +} + +impl ArithOperationTest { + // NOTE: use 0x0000_0000 instead of 0, to avoid auto-format in one line, 0 is too short. 
+ pub fn new() -> Self { + ArithOperationTest { + count: 0, + fail: 0, + fail_by_op: [0; 16], + pending: 0, + table_rows: [0; arith_table_data::ROWS], + } + } + fn test(&mut self) { + self.count = 0; + + for op in ALL_OPERATIONS { + let m32 = Self::is_m32_op(op); + for a in ALL_VALUES { + if m32 && a > 0xFFFF_FFFF { + continue; + } + for b in ALL_VALUES { + if m32 && b > 0xFFFF_FFFF { + continue; + } + println!("===> TEST CASE op:0x{:x} with a:0x{:X} b:0x{:X} <===", op, a, b); + let (emu_c, emu_flag) = Self::calculate_emulator_res(op, a, b); + self.test_operation(op, a, b, emu_c, emu_flag); + self.count += 1; + } + } + } + for index in 0..arith_table_data::ROWS { + if self.table_rows[index] == 0 { + println!( + "\x1B[31mTable row {0} not tested op:0x{1:x}({1}) flags:{2}\x1B[0m", + index, + arith_table_data::ARITH_TABLE[index][0], + ArithTableHelpers::flags_to_string(arith_table_data::ARITH_TABLE[index][1]), + ); + self.pending += 1; + } + } + println!("TOTAL TESTS:{} ERRORS: {}", self.count, self.fail); + } + + fn is_m32_op(op: u8) -> bool { + match op { + MUL | MULH | MULSUH | MULU | MULUH | DIVU | REMU | DIV | REM => false, + MUL_W | DIVU_W | REMU_W | DIV_W | REM_W => true, + _ => panic!("Invalid opcode"), + } + } + fn calculate_emulator_res(op: u8, a: u64, b: u64) -> (u64, bool) { + match op { + MULU => op_mulu(a, b), + MULUH => op_muluh(a, b), + MULSUH => op_mulsuh(a, b), + MUL => op_mul(a, b), + MULH => op_mulh(a, b), + MUL_W => op_mul_w(a, b), + DIVU => op_divu(a, b), + REMU => op_remu(a, b), + DIVU_W => op_divu_w(a, b), + REMU_W => op_remu_w(a, b), + DIV => op_div(a, b), + REM => op_rem(a, b), + DIV_W => op_div_w(a, b), + REM_W => op_rem_w(a, b), + _ => { + panic!("Invalid opcode"); + } + } + } + + fn test_operation(&mut self, op: u8, a: u64, b: u64, c: u64, flag: bool) { + let mut aop = ArithOperation::new(); + aop.calculate(op, a, b); + println!("testing op:0x{:x} a:0x{:X} b:0x{:X} c:0x{:X} flag:{}", op, a, b, c, flag); + let chunks = 
aop.calculate_chunks(); + for (i, chunk) in chunks.iter().enumerate() { + let carry_in = if i > 0 { aop.carry[i - 1] } else { 0 }; + let carry_out = if i < 7 { aop.carry[i] } else { 0 }; + let res = chunk + carry_in - 0x10000 * carry_out; + if res != 0 { + println!("{:#?}", aop); + + self.fail += 1; + self.fail_by_op[(op - 0xb0) as usize] += 1; + println!("\x1B[31mFAIL: 0x{4:X}({4})!= 0 chunks[{0}]=0x{1:X}({1}) carry_in: 0x{2:x},{2} carry_out: 0x{3:x},{3} failed\x1B[0m", + i, + chunk, + carry_in, + carry_out, + res); + } + } + println!("{:#?}", aop); + + const CHUNK_SIZE: u64 = 0x10000; + let bus_a_low: u64 = aop.div as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + let bus_a_high: u64 = aop.div as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + + let bus_b_low: u64 = aop.b[0] + CHUNK_SIZE * aop.b[1]; + let bus_b_high: u64 = aop.b[2] + CHUNK_SIZE * aop.b[3]; + + let secondary_res: u64 = if aop.main_mul || aop.main_div { 0 } else { 1 }; + + let bus_res_low = secondary_res * (aop.d[0] + aop.d[1] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + + let bus_res_high_64 = secondary_res * (aop.d[2] + aop.d[3] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + + let bus_res_high = if aop.sext && !aop.div_overflow { 0xFFFF_FFFF } else { 0 } + + (1 - aop.m32 as u64) * bus_res_high_64; + + let expected_a_low = a & 0xFFFF_FFFF; + let expected_a_high = (a >> 32) & 0xFFFF_FFFF; + let expected_b_low = b & 0xFFFF_FFFF; + let expected_b_high = (b >> 32) & 0xFFFF_FFFF; + let expected_res_low = c & 0xFFFF_FFFF; + let expected_res_high = (c >> 32) & 0xFFFF_FFFF; + + assert_eq!( + bus_a_low, expected_a_low, + "bus_a_low: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_a_low, expected_a_low + ); + 
assert_eq!( + bus_a_high, expected_a_high, + "bus_a_high: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_a_high, expected_a_high + ); + assert_eq!( + bus_b_low, expected_b_low, + "bus_b_low: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_b_low, expected_b_low + ); + assert_eq!( + bus_b_high, expected_b_high, + "bus_b_high: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_b_high, expected_b_high + ); + assert_eq!( + bus_res_low, expected_res_low, + "bus_c_low: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_res_low, expected_res_low + ); + assert_eq!( + bus_res_high, expected_res_high, + "bus_c_high: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_res_high, expected_res_high + ); + for i in 0..7 { + ArithRangeTableHelpers::get_row_carry_range_check(aop.carry[i]); + } + + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_ab, aop.a[3]); + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_ab + 26, aop.a[1]); + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_ab + 17, aop.b[3]); + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_ab + 9, aop.b[1]); + + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_cd, aop.c[3]); + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_cd + 26, aop.c[1]); + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_cd + 17, aop.d[3]); + ArithRangeTableHelpers::get_row_chunk_range_check(aop.range_cd + 9, aop.d[1]); + + for i in [0, 2] { + ArithRangeTableHelpers::get_row_chunk_range_check(0, aop.a[i]); + ArithRangeTableHelpers::get_row_chunk_range_check(0, aop.b[i]); + ArithRangeTableHelpers::get_row_chunk_range_check(0, aop.c[i]); + ArithRangeTableHelpers::get_row_chunk_range_check(0, aop.d[i]); + } + + let row_1 = ArithTableHelpers::get_row( + aop.op, + aop.na, + aop.nb, + aop.np, + aop.nr, + aop.sext, + aop.div_by_zero, + aop.div_overflow, + aop.m32, + aop.div, + aop.main_mul, + aop.main_div, + aop.signed, + aop.range_ab as u16, + aop.range_cd as u16, + ); + 
self.table_rows[row_1] += 1; + } +} + +#[test] +fn test() { + let mut test = ArithOperationTest::new(); + test.test(); + for i in 0..16 { + if test.fail_by_op[i] == 0 { + continue; + } + println!("fail_by_op[0x{:X}]: {}", i + 0xb0, test.fail_by_op[i]); + } + assert_eq!(test.fail, 0); +} diff --git a/state-machines/arith/src/arith_range_table.rs b/state-machines/arith/src/arith_range_table.rs new file mode 100644 index 00000000..4b3dcf1f --- /dev/null +++ b/state-machines/arith/src/arith_range_table.rs @@ -0,0 +1,118 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::ArithRangeTableInputs; +use log::info; +use p3_field::Field; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx}; +use rayon::prelude::*; +use sm_common::create_prover_buffer; +use zisk_pil::{ARITH_RANGE_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; + +pub struct ArithRangeTableSM { + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Inputs + num_rows: usize, + multiplicity: Mutex>, +} + +impl ArithRangeTableSM { + const MY_NAME: &'static str = "ArithRT "; + + pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, ARITH_RANGE_TABLE_AIR_IDS[0]); + let arith_range_table_sm = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + num_rows: air.num_rows(), + multiplicity: Mutex::new(vec![0; air.num_rows()]), + }; + let arith_range_table_sm = Arc::new(arith_range_table_sm); + + wcm.register_component(arith_range_table_sm.clone(), Some(airgroup_id), Some(air_ids)); + + arith_range_table_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.create_air_instance(); + } + } + pub fn 
process_slice(&self, inputs: &ArithRangeTableInputs) { + // Create the trace vector + let mut _multiplicity = self.multiplicity.lock().unwrap(); + + for (row, value) in inputs { + _multiplicity[row] += value; + } + } + pub fn create_air_instance(&self) { + let ectx = self.wcm.get_ectx(); + let mut dctx: std::sync::RwLockWriteGuard<'_, proofman_common::DistributionCtx> = + ectx.dctx.write().unwrap(); + let mut multiplicity = self.multiplicity.lock().unwrap(); + + let (is_myne, instance_global_idx) = + dctx.add_instance(ZISK_AIRGROUP_ID, ARITH_RANGE_TABLE_AIR_IDS[0], 1); + let owner: usize = dctx.owner(instance_global_idx); + + let mut multiplicity_ = std::mem::take(&mut *multiplicity); + dctx.distribute_multiplicity(&mut multiplicity_, owner); + + if is_myne { + // Create the prover buffer + let (mut prover_buffer, offset) = create_prover_buffer( + &self.wcm.get_ectx(), + &self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + ARITH_RANGE_TABLE_AIR_IDS[0], + ); + prover_buffer[offset as usize..offset as usize + self.num_rows] + .par_iter_mut() + .enumerate() + .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); + + info!( + "{}: ··· Creating Binary basic table instance [{} rows filled 100%]", + Self::MY_NAME, + self.num_rows, + ); + let air_instance = AirInstance::new( + self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + ARITH_RANGE_TABLE_AIR_IDS[0], + None, + prover_buffer, + ); + self.wcm + .get_pctx() + .air_instance_repo + .add_air_instance(air_instance, Some(instance_global_idx)); + } + } +} + +impl WitnessComponent for ArithRangeTableSM { + fn calculate_witness( + &self, + _stage: u32, + _air_instance: Option, + _pctx: Arc>, + _ectx: Arc>, + _sctx: Arc>, + ) { + } +} diff --git a/state-machines/arith/src/arith_range_table_helpers.rs b/state-machines/arith/src/arith_range_table_helpers.rs new file mode 100644 index 00000000..ab237198 --- /dev/null +++ b/state-machines/arith/src/arith_range_table_helpers.rs @@ -0,0 +1,234 @@ +use std::collections::HashMap; + 
/// Total number of rows of the arith range table (chunk ranges followed by the carry range).
const ROWS: usize = 1 << 22;
/// Range type: any 16-bit value (0x0000..=0xFFFF).
const FULL: u8 = 0x00;
/// Range type: "positive" chunk (0x0000..=0x7FFF).
const POS: u8 = 0x01;
/// Range type: "negative" chunk (0x8000..=0xFFFF).
const NEG: u8 = 0x02;
/// The `updated` bitmap has 64 bits, so each bit covers `ROWS / 64 = 2^16` consecutive rows.
const UPDATED_CHUNK_SHIFT: usize = 22 - 6;

/// Row-index helpers for the arith range table.
pub struct ArithRangeTableHelpers;

/// Range type of every range id (see `generate_table`).
const RANGES: [u8; 43] = [
    FULL, FULL, FULL, POS, POS, POS, NEG, NEG, NEG, FULL, FULL, FULL, FULL, FULL, FULL, FULL, FULL,
    FULL, POS, NEG, FULL, POS, NEG, FULL, POS, NEG, FULL, FULL, FULL, FULL, FULL, FULL, FULL, FULL,
    FULL, FULL, FULL, POS, POS, POS, NEG, NEG, NEG,
];
/// First row of every range id, in units of 0x8000 rows (see `generate_table`).
const OFFSETS: [usize; 43] = [
    0, 2, 4, 50, 51, 52, 59, 60, 61, 6, 8, 10, 12, 14, 16, 18, 20, 22, 53, 62, 24, 54, 63, 26, 55,
    64, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 56, 57, 58, 65, 66, 67,
];

impl ArithRangeTableHelpers {
    /// Returns the human-readable pattern of a `range_ab` / `range_cd` selector (0..=16).
    ///
    /// # Panics
    /// Panics if `range_index` is greater than 16.
    pub fn get_range_name(range_index: u8) -> &'static str {
        match range_index {
            0 => "F F F F",
            1 => "F F + F",
            2 => "F F - F",
            3 => "+ F F F",
            4 => "+ F + F",
            5 => "+ F - F",
            6 => "- F F F",
            7 => "- F + F",
            8 => "- F - F",
            9 => "F F F +",
            10 => "F F F -",
            11 => "F + F F",
            12 => "F + F +",
            13 => "F + F -",
            14 => "F - F F",
            15 => "F - F +",
            16 => "F - F -",
            _ => panic!("Invalid range index"),
        }
    }

    /// Returns the table row that range-checks chunk `value` against range id `range_index`.
    ///
    /// # Panics
    /// Panics if `range_index` is not a valid range id (0..43) or if `value` lies outside
    /// the range selected by it.
    pub fn get_row_chunk_range_check(range_index: u8, value: u64) -> usize {
        // F F F + + + - - - F F F F F F F F F + - F + - F + - F F F F F F F F F F F + + + - - -
        // Validate the id before using it as an index.
        assert!((range_index as usize) < RANGES.len());
        let (min, max): (u64, u64) = match RANGES[range_index as usize] {
            FULL => (0, 0xFFFF),
            POS => (0, 0x7FFF),
            NEG => (0x8000, 0xFFFF),
            _ => panic!("Invalid range type"),
        };
        assert!(value >= min);
        assert!(value <= max);
        OFFSETS[range_index as usize] * 0x8000 + (value - min) as usize
    }

    /// Returns the table row that range-checks carry `value`.
    ///
    /// # Panics
    /// Panics if `value` is outside `-0xEFFFF..=0xF0000`.
    pub fn get_row_carry_range_check(value: i64) -> usize {
        assert!(value >= -0xEFFFF);
        assert!(value <= 0xF0000);
        (0x220000 + 0xEFFFF + value) as usize
    }
}

/// Per-row use counters for the arith range table.
///
/// Counters are kept as 16 bits per row; every time a row counter wraps past `u16::MAX`, the
/// number of wraps (units of 2^16) is accumulated in `multiplicity_overflow`.
pub struct ArithRangeTableInputs {
    // TODO: check improvement of multiplicity[64] to reserv only chunks used
    // with this 16 bits version, this table has aprox 8MB.
    /// Bit `i` is set when some row in `[i << 16, (i + 1) << 16)` has been incremented.
    updated: u64,
    /// row -> number of times its 16-bit counter wrapped.
    multiplicity_overflow: HashMap<u32, u32>,
    /// Low 16 bits of the use count of every row.
    multiplicity: Vec<u16>,
}

impl Default for ArithRangeTableInputs {
    fn default() -> Self {
        Self::new()
    }
}

impl ArithRangeTableInputs {
    /// Creates an empty set of counters (all rows at zero).
    pub fn new() -> Self {
        ArithRangeTableInputs {
            updated: 0,
            multiplicity_overflow: HashMap::new(),
            multiplicity: vec![0u16; ROWS],
        }
    }

    /// Flags the 2^16-row chunk that contains `row` as touched.
    fn mark_updated(&mut self, row: usize) {
        // Must be an OR: `updated` starts at 0, so an AND could never set a bit and
        // `update_with` would skip every chunk.
        self.updated |= 1u64 << (row >> UPDATED_CHUNK_SHIFT);
    }

    fn incr_row_one(&mut self, row: usize) {
        self.incr_row(row, 1);
    }

    fn incr_row(&mut self, row: usize, times: usize) {
        self.incr_row_without_update(row, times);
        self.mark_updated(row);
    }

    /// Adds `times` uses to `row` without touching the `updated` bitmap.
    fn incr_row_without_update(&mut self, row: usize, times: usize) {
        let total = self.multiplicity[row] as u64 + times as u64;
        let wraps = total >> 16;
        if wraps > 0 {
            *self.multiplicity_overflow.entry(row as u32).or_insert(0) += wraps as u32;
        }
        self.multiplicity[row] = (total & 0xFFFF) as u16;
    }

    /// Records one use of chunk `value` under range id `range_id`.
    pub fn use_chunk_range_check(&mut self, range_id: u8, value: u64) {
        let row = ArithRangeTableHelpers::get_row_chunk_range_check(range_id, value);
        self.incr_row_one(row);
    }

    /// Records one use of carry `value`.
    pub fn use_carry_range_check(&mut self, value: i64) {
        let row = ArithRangeTableHelpers::get_row_carry_range_check(value);
        self.incr_row_one(row);
    }

    /// Records `times` uses of chunk `value` under range id `range_id`.
    pub fn multi_use_chunk_range_check(&mut self, times: usize, range_id: u8, value: u64) {
        let row = ArithRangeTableHelpers::get_row_chunk_range_check(range_id, value);
        self.incr_row(row, times);
    }

    /// Records `times` uses of carry `value`.
    pub fn multi_use_carry_range_check(&mut self, times: usize, value: i64) {
        let row = ArithRangeTableHelpers::get_row_carry_range_check(value);
        self.incr_row(row, times);
    }

    /// Adds every counter of `other` into `self`.
    ///
    /// Only the 2^16-row chunks flagged in `other.updated` are scanned.
    pub fn update_with(&mut self, other: &Self) {
        let chunk_size: usize = 1 << UPDATED_CHUNK_SHIFT;
        for i_chunk in 0..64usize {
            if (other.updated & (1u64 << i_chunk)) == 0 {
                continue;
            }
            let from = chunk_size * i_chunk;
            let to = from + chunk_size;
            for row in from..to {
                let count = other.multiplicity[row];
                if count > 0 {
                    self.incr_row_without_update(row, count as usize);
                }
            }
        }
        // Overflow entries are already expressed in units of 2^16, so they add up directly.
        for (row, wraps) in other.multiplicity_overflow.iter() {
            *self.multiplicity_overflow.entry(*row).or_insert(0) += *wraps;
        }
        self.updated |= other.updated;
    }
}

/// Iterator over `(row, uses)` pairs of an [`ArithRangeTableInputs`].
///
/// It first yields every row whose 16-bit counter is non-zero, in ascending row order, and
/// then one extra `(row, wraps << 16)` pair per overflow entry; a row may therefore appear
/// twice and consumers are expected to add the values up.
pub struct ArithRangeTableInputsIterator<'a> {
    next_row: usize,
    overflow: std::collections::hash_map::Iter<'a, u32, u32>,
    inputs: &'a ArithRangeTableInputs,
}

impl<'a> Iterator for ArithRangeTableInputsIterator<'a> {
    type Item = (usize, u64);

    fn next(&mut self) -> Option<Self::Item> {
        let counters = &self.inputs.multiplicity;
        while self.next_row < counters.len() {
            let row = self.next_row;
            self.next_row += 1;
            if counters[row] != 0 {
                return Some((row, counters[row] as u64));
            }
        }
        self.overflow.next().map(|(row, wraps)| (*row as usize, (*wraps as u64) << 16))
    }
}

impl<'a> IntoIterator for &'a ArithRangeTableInputs {
    type Item = (usize, u64);
    type IntoIter = ArithRangeTableInputsIterator<'a>;

    fn into_iter(self) -> Self::IntoIter {
        ArithRangeTableInputsIterator {
            next_row: 0,
            overflow: self.multiplicity_overflow.iter(),
            inputs: self,
        }
    }
}

/// Prints the `RANGES` and `OFFSETS` constants derived from the range pattern.
///
/// FULL ranges take two 0x8000-row slots and are laid out first, then the POS ranges, then
/// the NEG ranges (one slot each).
#[cfg(feature = "generate_code_arith_range_table")]
#[allow(dead_code)]
fn generate_table() {
    let pattern = "FFF+++---FFFFFFFFF+-F+-F+-FFFFFFFFFFF+++---";
    let mut ranges = String::new();
    let mut offsets = [0usize; 43];
    let mut offset = 0;
    for range_loop in [FULL, POS, NEG] {
        let mut index = 0;
        for c in pattern.chars() {
            if c == ' ' || c == '_' {
                continue;
            }
            let range_id = match c {
                'F' => FULL,
                '+' => POS,
                '-' => NEG,
                _ => panic!("Invalid character in pattern"),
            };
            // The textual RANGES list is emitted only once, during the first pass.
            if range_loop == FULL {
                if index > 0 {
                    ranges.push_str(", ")
                }
                ranges.push_str(match range_id {
                    FULL => "FULL",
                    POS => "POS",
                    _ => "NEG",
                });
            }
            if range_loop == range_id {
                offsets[index] = offset;
                offset += if range_loop == FULL { 2 } else { 1 };
            }
            index += 1;
        }
    }
    println!("const RANGES: [u8; 43] = [{}];", ranges);
    println!("const OFFSETS: [usize; 43] = {:?};", offsets);
}
let range_id = match c { + 'F' => FULL, + '+' => POS, + '-' => NEG, + _ => panic!("Invalid character in pattern"), + }; + if range_loop == FULL { + if index > 0 { + ranges.push_str(", ") + } + ranges.push_str(match range_id { + FULL => "FULL", + POS => "POS", + _ => "NEG", + }); + // ranges[index] = range_id + } + if range_loop == range_id { + offsets[index] = offset; + offset += if range_loop == FULL { 2 } else { 1 }; + } + index += 1; + } + } + println!("const RANGES: [u8; 43] = [{}];", ranges); + println!("const OFFSETS: [usize; 43] = {:?};", offsets); +} diff --git a/state-machines/arith/src/arith_table.rs b/state-machines/arith/src/arith_table.rs new file mode 100644 index 00000000..6805f407 --- /dev/null +++ b/state-machines/arith/src/arith_table.rs @@ -0,0 +1,120 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::ArithTableInputs; +use log::info; +use p3_field::Field; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx}; +use rayon::prelude::*; +use sm_common::create_prover_buffer; +use zisk_pil::{ARITH_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; + +pub struct ArithTableSM { + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Inputs + num_rows: usize, + multiplicity: Mutex>, +} + +impl ArithTableSM { + const MY_NAME: &'static str = "ArithT "; + + pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, ARITH_TABLE_AIR_IDS[0]); + let _arith_table_sm = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + num_rows: air.num_rows(), + multiplicity: Mutex::new(vec![0; air.num_rows()]), + }; + let arith_table_sm = Arc::new(_arith_table_sm); + + wcm.register_component(arith_table_sm.clone(), Some(airgroup_id), Some(air_ids)); + + arith_table_sm + } + + pub fn register_predecessor(&self) { + 
self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.create_air_instance(); + } + } + pub fn process_slice(&self, inputs: &ArithTableInputs) { + // Create the trace vector + let mut _multiplicity = self.multiplicity.lock().unwrap(); + + info!("{}: ··· process multiplicity", Self::MY_NAME); + for (row, value) in inputs { + info!("{}: ··· Processing row {} with value {}", Self::MY_NAME, row, value); + _multiplicity[row] += value; + } + } + pub fn create_air_instance(&self) { + let ectx = self.wcm.get_ectx(); + let mut dctx: std::sync::RwLockWriteGuard<'_, proofman_common::DistributionCtx> = + ectx.dctx.write().unwrap(); + let mut multiplicity = self.multiplicity.lock().unwrap(); + + let (is_myne, instance_global_idx) = + dctx.add_instance(ZISK_AIRGROUP_ID, ARITH_TABLE_AIR_IDS[0], 1); + let owner: usize = dctx.owner(instance_global_idx); + + let mut multiplicity_ = std::mem::take(&mut *multiplicity); + dctx.distribute_multiplicity(&mut multiplicity_, owner); + + if is_myne { + // Create the prover buffer + let (mut prover_buffer, offset) = create_prover_buffer( + &self.wcm.get_ectx(), + &self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + ARITH_TABLE_AIR_IDS[0], + ); + prover_buffer[offset as usize..offset as usize + self.num_rows] + .par_iter_mut() + .enumerate() + .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); + + info!( + "{}: ··· Creating Binary basic table instance [{} rows filled 100%]", + Self::MY_NAME, + self.num_rows, + ); + let air_instance = AirInstance::new( + self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + ARITH_TABLE_AIR_IDS[0], + None, + prover_buffer, + ); + self.wcm + .get_pctx() + .air_instance_repo + .add_air_instance(air_instance, Some(instance_global_idx)); + } + } +} + +impl WitnessComponent for ArithTableSM { + fn calculate_witness( + &self, + _stage: u32, + _air_instance: Option, + _pctx: Arc>, + 
_ectx: Arc>, + _sctx: Arc>, + ) { + } +} diff --git a/state-machines/arith/src/arith_table_data.rs b/state-machines/arith/src/arith_table_data.rs new file mode 100644 index 00000000..ce2335e3 --- /dev/null +++ b/state-machines/arith/src/arith_table_data.rs @@ -0,0 +1,167 @@ +pub const FIRST_OP: u8 = 0xb0; +pub const ROWS: usize = 74; +const __: u8 = 255; +pub static ARITH_TABLE_ROWS: [u8; 2048] = [ + 0, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, 1, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + 2, 3, __, __, __, 4, __, __, 
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, 5, 6, 7, 8, __, 9, 10, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 11, 12, 13, 14, __, 15, 16, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 17, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 18, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 
__, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 19, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 20, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 21, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, 22, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, 23, __, 24, 25, __, 26, 27, __, __, __, __, __, 28, 29, 30, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 31, __, __, __, __, __, __, __, + __, __, __, __, 32, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, 33, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, 
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 34, __, 35, 36, __, 37, 38, __, __, + __, __, __, 39, 40, 41, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + 42, __, __, __, __, __, __, __, __, __, __, __, 43, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, 44, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 45, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 46, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 47, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, 48, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 49, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 50, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, 51, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 52, __, 53, __, __, __, 54, __, __, + __, __, __, 55, __, 56, __, __, __, __, 57, __, 58, __, __, __, __, __, __, __, 59, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 60, __, __, __, __, __, __, __, + __, __, __, __, 61, __, __, __, __, __, __, __, __, 
__, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, 62, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, 63, + __, 64, 65, __, 66, 67, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, 68, 69, 70, __, __, 71, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, 72, __, __, __, __, __, __, __, __, __, 73, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, + __, __, __, __, __, __, __, +]; +pub static ARITH_TABLE: [[u16; 4]; ROWS] = [ + [176, 512, 0, 0], + [177, 0, 0, 0], + [179, 2048, 3, 1], + [179, 2052, 6, 1], + [179, 2068, 6, 2], + [180, 2560, 4, 1], + [180, 2564, 7, 1], + [180, 2568, 5, 1], + [180, 2572, 8, 1], + [180, 2580, 7, 2], + [180, 2584, 5, 2], + [181, 2048, 4, 1], + [181, 2052, 7, 1], + [181, 2056, 5, 1], + [181, 2060, 8, 1], + [181, 2068, 7, 2], + [181, 2072, 5, 2], + [182, 513, 0, 11], + [182, 577, 0, 14], + [184, 1026, 0, 0], + [184, 1154, 0, 0], + [185, 2, 0, 0], + [185, 130, 0, 0], + [186, 3074, 4, 4], + [186, 3082, 5, 4], + [186, 3086, 8, 4], + [186, 3094, 7, 7], + [186, 3098, 5, 7], + [186, 3122, 4, 8], + [186, 3126, 7, 8], + [186, 3130, 5, 8], + [186, 3206, 7, 4], + [186, 3254, 7, 8], + [186, 3358, 8, 7], + [187, 2050, 4, 4], + [187, 2058, 5, 4], + [187, 2062, 8, 4], + [187, 2070, 7, 7], + [187, 2074, 5, 7], + [187, 2098, 4, 8], + [187, 2102, 7, 8], + [187, 2106, 5, 8], + [187, 2182, 7, 4], + [187, 2230, 7, 8], + [187, 2334, 8, 7], + [188, 1027, 11, 0], + [188, 1091, 14, 0], + [188, 1219, 14, 0], + [189, 3, 0, 9], + [189, 67, 0, 10], + [189, 131, 0, 9], + [189, 195, 0, 10], + [190, 3075, 12, 12], + [190, 3083, 13, 12], + [190, 3099, 13, 15], + [190, 3123, 12, 16], + [190, 3131, 
13, 16], + [190, 3151, 16, 12], + [190, 3159, 15, 15], + [190, 3191, 15, 16], + [190, 3271, 15, 12], + [190, 3319, 15, 16], + [190, 3423, 16, 15], + [191, 2051, 12, 12], + [191, 2059, 13, 12], + [191, 2063, 16, 12], + [191, 2071, 15, 15], + [191, 2075, 13, 15], + [191, 2163, 12, 16], + [191, 2167, 15, 16], + [191, 2171, 13, 16], + [191, 2183, 15, 12], + [191, 2295, 15, 16], + [191, 2335, 16, 15], +]; diff --git a/state-machines/arith/src/arith_table_helpers.rs b/state-machines/arith/src/arith_table_helpers.rs new file mode 100644 index 00000000..820c5361 --- /dev/null +++ b/state-machines/arith/src/arith_table_helpers.rs @@ -0,0 +1,244 @@ +pub struct ArithTableHelpers; + +#[cfg(debug_assertions)] +use crate::ARITH_TABLE; + +use crate::{ARITH_TABLE_ROWS, FIRST_OP, ROWS}; + +impl ArithTableHelpers { + #[allow(clippy::too_many_arguments)] + pub fn direct_get_row( + op: u8, + na: bool, + nb: bool, + np: bool, + nr: bool, + sext: bool, + div_by_zero: bool, + div_overflow: bool, + ) -> usize { + let index = (op - FIRST_OP) as u64 * 128 + + na as u64 + + nb as u64 * 2 + + np as u64 * 4 + + nr as u64 * 8 + + sext as u64 * 16 + + div_by_zero as u64 * 32 + + div_overflow as u64 * 64; + assert!(index < ARITH_TABLE_ROWS.len() as u64); + let row = ARITH_TABLE_ROWS[index as usize]; + assert!( + row < 255, + "INVALID ROW row:{} op:0x{:x} na:{} nb:{} np:{} nr:{} sext:{} div_by_zero:{} div_overflow:{} index:{}", + row, + op, + na as u8, + nb as u8, + np as u8, + nr as u8, + sext as u8, + div_by_zero as u8, + div_overflow as u8, + index + ); + row as usize + } + #[cfg(not(debug_assertions))] + pub fn get_row( + op: u8, + na: bool, + nb: bool, + np: bool, + nr: bool, + sext: bool, + div_by_zero: bool, + div_overflow: bool, + ) -> usize { + Self::direct_get_row(op, na, nb, np, nr, sext, div_by_zero, div_overflow) + } + #[cfg(debug_assertions)] + #[allow(clippy::too_many_arguments)] + pub fn get_row( + op: u8, + na: bool, + nb: bool, + np: bool, + nr: bool, + sext: bool, + div_by_zero: 
bool, + div_overflow: bool, + m32: bool, + div: bool, + main_mul: bool, + main_div: bool, + signed: bool, + range_ab: u16, + range_cd: u16, + ) -> usize { + let flags = if m32 { 1 } else { 0 } + + if div { 2 } else { 0 } + + if na { 4 } else { 0 } + + if nb { 8 } else { 0 } + + if np { 16 } else { 0 } + + if nr { 32 } else { 0 } + + if sext { 64 } else { 0 } + + if div_by_zero { 128 } else { 0 } + + if div_overflow { 256 } else { 0 } + + if main_mul { 512 } else { 0 } + + if main_div { 1024 } else { 0 } + + if signed { 2048 } else { 0 }; + let row = Self::direct_get_row(op, na, nb, np, nr, sext, div_by_zero, div_overflow); + assert_eq!( + op as u16, ARITH_TABLE[row][0], + "at row {} not match op {} vs {}", + row, op, ARITH_TABLE[row][0] + ); + assert_eq!( + flags, ARITH_TABLE[row][1], + "at row {0} op:0x{1:x}({1}) not match flags {2:b}({2}) vs {3:b}({3})", + row, op, flags, ARITH_TABLE[row][1] + ); + assert_eq!( + range_ab, ARITH_TABLE[row][2], + "at row {} op:{} not match range_ab {} vs {}", + row, op, flags, ARITH_TABLE[row][2] + ); + assert_eq!( + range_cd, ARITH_TABLE[row][3], + "at row {} op:{} not match range_cd {} vs {}", + row, op, flags, ARITH_TABLE[row][3] + ); + row + } + + pub fn flags_to_string(flags: u16) -> String { + let mut result = String::new(); + if flags & 1 != 0 { + result += " m32"; + } + if flags & 2 != 0 { + result += " div"; + } + if flags & 4 != 0 { + result += " na"; + } + if flags & 8 != 0 { + result += " nb"; + } + if flags & 16 != 0 { + result += " np"; + } + if flags & 32 != 0 { + result += " nr"; + } + if flags & 64 != 0 { + result += " sext"; + } + if flags & 128 != 0 { + result += " div_by_zero"; + } + if flags & 256 != 0 { + result += " div_overflow"; + } + if flags & 512 != 0 { + result += " main_mul"; + } + if flags & 1024 != 0 { + result += " main_div"; + } + if flags & 2048 != 0 { + result += " signed"; + } + result + } + + pub fn get_max_row() -> usize { + ROWS - 1 + } +} + +pub struct ArithTableInputs { + multiplicity: 
[u64; ROWS], +} + +impl Default for ArithTableInputs { + fn default() -> Self { + Self::new() + } +} + +impl ArithTableInputs { + pub fn new() -> Self { + ArithTableInputs { multiplicity: [0; ROWS] } + } + #[allow(clippy::too_many_arguments)] + pub fn add_use( + &mut self, + op: u8, + na: bool, + nb: bool, + np: bool, + nr: bool, + sext: bool, + div_by_zero: bool, + div_overflow: bool, + ) { + let row = + ArithTableHelpers::direct_get_row(op, na, nb, np, nr, sext, div_by_zero, div_overflow); + assert!(row < ROWS); + self.multiplicity[row] += 1; + } + #[allow(clippy::too_many_arguments)] + pub fn multi_add_use( + &mut self, + times: usize, + op: u8, + na: bool, + nb: bool, + np: bool, + nr: bool, + sext: bool, + div_by_zero: bool, + div_overflow: bool, + ) { + let row = + ArithTableHelpers::direct_get_row(op, na, nb, np, nr, sext, div_by_zero, div_overflow); + self.multiplicity[row] += times as u64; + } + pub fn update_with(&mut self, other: &Self) { + for i in 0..ROWS { + self.multiplicity[i] += other.multiplicity[i]; + } + } +} + +pub struct ArithTableInputsIterator<'a> { + iter_row: u32, + inputs: &'a ArithTableInputs, +} + +impl<'a> Iterator for ArithTableInputsIterator<'a> { + type Item = (usize, u64); + + fn next(&mut self) -> Option { + while self.iter_row < ROWS as u32 && self.inputs.multiplicity[self.iter_row as usize] == 0 { + self.iter_row += 1; + } + let row = self.iter_row as usize; + if row < ROWS { + self.iter_row += 1; + Some((row, self.inputs.multiplicity[row])) + } else { + None + } + } +} + +impl<'a> IntoIterator for &'a ArithTableInputs { + type Item = (usize, u64); + type IntoIter = ArithTableInputsIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + ArithTableInputsIterator { iter_row: 0, inputs: self } + } +} diff --git a/state-machines/arith/src/arith_traces.rs b/state-machines/arith/src/arith_traces.rs deleted file mode 100644 index 6552f02d..00000000 --- a/state-machines/arith/src/arith_traces.rs +++ /dev/null @@ -1,6 +0,0 @@ -use 
proofman_common as common; -pub use proofman_macros::trace; - -trace!(Arith32Row, Arith32Trace { fake: F }); -trace!(Arith64Row, Arith64Trace { fake: F }); -trace!(Arith3264Row, Arith3264Trace { fake: F }); diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index 8297735d..d7250835 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -1,11 +1,22 @@ mod arith; -mod arith_32; -mod arith_3264; -mod arith_64; -mod arith_traces; +mod arith_constants; +mod arith_full; +mod arith_operation; +mod arith_range_table; +mod arith_range_table_helpers; +mod arith_table; +mod arith_table_data; +mod arith_table_helpers; + +#[cfg(test)] +mod arith_operation_test; pub use arith::*; -pub use arith_32::*; -pub use arith_3264::*; -pub use arith_64::*; -pub use arith_traces::*; +pub use arith_constants::*; +pub use arith_full::*; +pub use arith_operation::*; +pub use arith_range_table::*; +pub use arith_range_table_helpers::*; +pub use arith_table::*; +pub use arith_table_data::*; +pub use arith_table_helpers::*; diff --git a/state-machines/binary/pil/binary_extension.pil b/state-machines/binary/pil/binary_extension.pil index 2e1587f4..c26594d8 100644 --- a/state-machines/binary/pil/binary_extension.pil +++ b/state-machines/binary/pil/binary_extension.pil @@ -44,7 +44,7 @@ x in1[x] out[x][0] out[x][1] 1 0x22 0x00220000 0x00000000 2 0x33 0x33000000 0x00000000 3 0x44 0x00000000 0x00000044 -4 0x55 0x00000000 0x00000000 (since 0x44 & 0x80 = 0, we stop here and set the remaining bytes to 0x00) +4 0x55 0x00000000 0x00000000 (since 0x44 & 0x80 = 0, we stop here and set the remaining bytes to 0x00) 5 0x66 0x00000000 0x00000000 (bytes of in1 are ignored from here) 6 0x77 0x00000000 0x00000000 7 0x88 0x00000000 0x00000000 @@ -72,7 +72,7 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BI const int bits = 64; const int bytes = bits / 8; - col witness op; + col witness op; col witness in1[bytes]; col 
witness in2_low; // Note: if in2_low∊[0,2^5-1], else in2_low∊[0,2^6-1] (checked by the table) col witness out[bytes][2]; @@ -108,4 +108,4 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BI ); range_check(colu: in2[0], min: 0, max: 2**24-1, sel: op_is_shift); -} \ No newline at end of file +} diff --git a/state-machines/common/src/field.rs b/state-machines/common/src/field.rs new file mode 100644 index 00000000..55d2c919 --- /dev/null +++ b/state-machines/common/src/field.rs @@ -0,0 +1,8 @@ +pub fn i64_to_u64_field(value: i64) -> u64 { + const PRIME_MINUS_ONE: u64 = 0xFFFF_FFFF_0000_0000; + if value >= 0 { + value as u64 + } else { + PRIME_MINUS_ONE - (0xFFFF_FFFF_FFFF_FFFF - value as u64) + } +} diff --git a/state-machines/common/src/lib.rs b/state-machines/common/src/lib.rs index 4f1f27e9..bb6f10ee 100644 --- a/state-machines/common/src/lib.rs +++ b/state-machines/common/src/lib.rs @@ -1,9 +1,11 @@ +mod field; mod operations; mod provable; mod session; mod temp; mod worker; +pub use field::*; pub use operations::*; use proofman_common::{ExecutionCtx, SetupCtx}; use proofman_util::create_buffer_fast; diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index 3c8c0d6c..30240f83 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -14,7 +14,8 @@ use proofman::WitnessComponent; use sm_arith::ArithSM; use sm_mem::MemSM; use zisk_pil::{ - MainRow, MainTrace, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, + MainRow, MainTrace, ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, + ZISK_AIRGROUP_ID, }; use ziskemu::{Emu, EmuTrace, ZiskEmulator}; @@ -28,7 +29,7 @@ pub struct MainSM { wcm: Arc>, /// Arithmetic state machine - arith_sm: Arc, + arith_sm: Arc>, /// Binary state machine binary_sm: Arc>, @@ -53,7 +54,7 @@ impl MainSM { /// * Arc to the MainSM state machine pub fn new( wcm: Arc>, - arith_sm: Arc, + arith_sm: Arc>, 
binary_sm: Arc>, mem_sm: Arc, ) -> Arc { @@ -151,7 +152,6 @@ impl MainSM { { partial_trace[i] = emu.step_slice_full_trace(emu_trace_step); } - // if there are steps in the chunk update last row if slice_end - slice_start > 0 { last_row = partial_trace[slice_end - slice_start - 1]; @@ -189,6 +189,42 @@ impl MainSM { iectx.air_instance = Some(air_instance); } + pub fn prove_arith( + &self, + zisk_rom: &ZiskRom, + vec_traces: &[EmuTrace], + iectx: &mut InstanceExtensionCtx, + pctx: &ProofCtx, + ) { + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, ARITH_AIR_IDS[0]); + + timer_start_debug!(PROCESS_ARITH); + let inputs = ZiskEmulator::process_slice_required::( + zisk_rom, + vec_traces, + iectx.op_type, + &iectx.emu_trace_start, + air.num_rows(), + ); + timer_stop_and_log_debug!(PROCESS_ARITH); + + timer_start_debug!(PROVE_ARITH); + + self.arith_sm.prove_instance(inputs, &mut iectx.prover_buffer, iectx.offset); + timer_stop_and_log_debug!(PROVE_ARITH); + + timer_start_debug!(CREATE_AIR_INSTANCE); + let buffer = std::mem::take(&mut iectx.prover_buffer); + iectx.air_instance = Some(AirInstance::new( + self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + ARITH_AIR_IDS[0], + None, + buffer, + )); + timer_stop_and_log_debug!(CREATE_AIR_INSTANCE); + } + pub fn prove_binary( &self, zisk_rom: &ZiskRom, diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs new file mode 100644 index 00000000..4e177ee3 --- /dev/null +++ b/state-machines/mem/src/mem_constants.rs @@ -0,0 +1,12 @@ +pub const MEM_ADDR_MASK: u64 = 0xFFFF_FFFF_FFFF_FFF8; +pub const MEM_BYTES: u64 = 8; + +pub const MAX_MEM_STEP_OFFSET: u64 = 2; +pub const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; + +pub const MEM_STEP_BITS: u64 = 34; // with step_slot = 8 => 2GB steps ( +pub const MEM_STEP_MASK: u64 = (1 << MEM_STEP_BITS) - 1; // 256 MB +pub const MEM_ADDR_BITS: u64 = 64 - MEM_STEP_BITS; + +pub const MAX_MEM_STEP: u64 = (1 << MEM_STEP_BITS) - 1; +pub const 
MAX_MEM_ADDR: u64 = (1 << MEM_ADDR_BITS) - 1; diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs new file mode 100644 index 00000000..ac4ca198 --- /dev/null +++ b/state-machines/mem/src/mem_helpers.rs @@ -0,0 +1,65 @@ +use crate::MemAlignResponse; +use std::fmt; +use zisk_core::ZiskRequiredMemory; + +fn format_u64_hex(value: u64) -> String { + let hex_str = format!("{:016x}", value); + hex_str + .as_bytes() + .chunks(4) + .map(|chunk| std::str::from_utf8(chunk).unwrap()) + .collect::>() + .join("_") +} + +impl fmt::Debug for MemAlignResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "more:{0} step:{1} value:{2:016X}({2:})", + self.more_address, + self.step, + self.value.unwrap_or(0) + ) + } +} + +pub fn mem_align_call( + mem_op: &ZiskRequiredMemory, + mem_values: [u64; 2], + phase: u8, +) -> MemAlignResponse { + // DEBUG: only for testing + let offset = (mem_op.address & 0x7) * 8; + let width = (mem_op.width as u64) * 8; + let double_address = (offset + width as u32) > 64; + let mem_value = mem_values[phase as usize]; + let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); + if mem_op.is_write { + if phase == 0 { + MemAlignResponse { + more_address: double_address, + step: mem_op.step + 1, + value: Some( + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) + | ((mem_op.value & mask) << offset), + ), + } + } else { + MemAlignResponse { + more_address: false, + step: mem_op.step + 1, + value: Some( + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) + | ((mem_op.value & mask) >> (128 - (offset + width as u32))), + ), + } + } + } else { + MemAlignResponse { + more_address: double_address && phase == 0, + step: mem_op.step + 1, + value: None, + } + } +} diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 0552b3d0..b77f6bfe 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -18,7 
+18,9 @@ use std::{ sync::Arc, }; use zisk_core::{Riscv2zisk, ZiskOperationType, ZiskRom, ZISK_OPERATION_TYPE_VARIANTS}; -use zisk_pil::{BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{ + ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, +}; use ziskemu::{EmuOptions, ZiskEmulator}; pub struct ZiskExecutor { @@ -38,7 +40,7 @@ pub struct ZiskExecutor { pub binary_sm: Arc>, /// Arithmetic State Machine - pub arith_sm: Arc, + pub arith_sm: Arc>, } impl ZiskExecutor { @@ -125,12 +127,14 @@ impl ZiskExecutor { // machine. We aim to track the starting point of execution for every N instructions // across different operation types. Currently, we are only collecting data for // Binary and BinaryE operations. + let air_arith = pctx.pilout.get_air(ZISK_AIRGROUP_ID, ARITH_AIR_IDS[0]); let air_binary = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]); let air_binary_e = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); let mut op_sizes = [0u64; ZISK_OPERATION_TYPE_VARIANTS]; // The starting points for the Main is allocated using None operation op_sizes[ZiskOperationType::None as usize] = air_main.num_rows() as u64; + op_sizes[ZiskOperationType::Arith as usize] = air_arith.num_rows() as u64; op_sizes[ZiskOperationType::Binary as usize] = air_binary.num_rows() as u64; op_sizes[ZiskOperationType::BinaryE as usize] = air_binary_e.num_rows() as u64; @@ -175,6 +179,7 @@ impl ZiskExecutor { for emu_slice in emu_slices.points.iter() { let (airgroup_id, air_id) = match emu_slice.op_type { ZiskOperationType::None => (ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]), + ZiskOperationType::Arith => (ZISK_AIRGROUP_ID, ARITH_AIR_IDS[0]), ZiskOperationType::Binary => (ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]), ZiskOperationType::BinaryE => (ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]), _ => panic!("Invalid operation type"), @@ -206,6 +211,9 @@ impl ZiskExecutor { ZiskOperationType::None => { 
self.main_sm.prove_main(&self.zisk_rom, &emu_traces, iectx, &pctx); } + ZiskOperationType::Arith => { + self.main_sm.prove_arith(&self.zisk_rom, &emu_traces, iectx, &pctx); + } ZiskOperationType::Binary => { self.main_sm.prove_binary(&self.zisk_rom, &emu_traces, iectx, &pctx); } @@ -230,6 +238,6 @@ impl ZiskExecutor { // self.mem_sm.unregister_predecessor(scope); self.binary_sm.unregister_predecessor(); - // self.arith_sm.register_predecessor(scope); + self.arith_sm.unregister_predecessor(); } }