From 8885ae8621bdb218c534c9bbc26ee914893d9e66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip=20Ardevol?= Date: Wed, 11 Dec 2024 12:23:20 +0100 Subject: [PATCH 1/6] Update proofman (#186) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Feature/custom commits (#166) * Custom cols rom (#159) Custom cols working --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> * Cached custom commits * Updating proofman to 0.0.12 * Cargo fmt * Cargo fmt * Fix cargo clippy * Rom trace is now deterministic * cargo fmt * Global constraints verifying again * Optimizing the binary component (#167) * Optimizing the binary * Updating the executor * Updating to 0.0.13 * Not creating unnecessary instances of arith tables * Pil2-proofman 0.0.14 --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Co-authored-by: Héctor Masip Ardevol * Removing unnecessary code * Zisk working with last proofman version * Updating book and Cargo.toml to point to 0.0.16 proofman * cargo fmt --------- Co-authored-by: Roger Taulé Buxadera <55488871+RogerTaule@users.noreply.github.com> Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Co-authored-by: RogerTaule --- Cargo.lock | 227 +++++---- Cargo.toml | 14 +- book/getting_started/quickstart.md | 10 +- book/getting_started/quickstart_dev.md | 13 +- pil/constants.pil | 58 --- pil/src/pil_helpers/traces.rs | 8 +- pil/zisk.pil | 2 - rom-merkle/Cargo.toml | 2 + rom-merkle/src/main.rs | 53 +- state-machines/arith/src/arith.rs | 9 +- state-machines/arith/src/arith_full.rs | 11 +- state-machines/arith/src/arith_range_table.rs | 29 +- state-machines/arith/src/arith_table.rs | 29 +- state-machines/binary/pil/binary.pil | 90 ++-- .../binary/pil/binary_extension.pil | 5 +- .../binary/pil/binary_extension_table.pil | 2 +- state-machines/binary/pil/binary_table.pil | 5 +- state-machines/binary/src/binary.rs | 5 +- state-machines/binary/src/binary_basic.rs | 64 ++- .../binary/src/binary_basic_table.rs | 14 +- state-machines/binary/src/binary_extension.rs | 23 +- .../binary/src/binary_extension_table.rs | 15 +- state-machines/common/src/lib.rs | 20 - state-machines/freq-ops/src/freq_ops.rs | 4 +- state-machines/main/pil/main.pil | 5 +- state-machines/main/src/instance_extension.rs | 14 +- state-machines/main/src/main_sm.rs | 54 +- state-machines/mem/src/mem.rs | 4 +- state-machines/mem/src/mem_aligned.rs | 4 +- state-machines/mem/src/mem_unaligned.rs | 4 +- state-machines/publics.json | 6 + state-machines/quick-ops/src/quick_ops.rs | 4 +- state-machines/rom/pil/rom.pil | 25 +- state-machines/rom/src/rom.rs | 469 +++++------------- witness-computation/src/executor.rs | 8 +- witness-computation/src/zisk_lib.rs | 17 +- 36 files changed, 505 insertions(+), 821 deletions(-) delete mode 100644 pil/constants.pil create mode 100644 state-machines/publics.json diff --git a/Cargo.lock b/Cargo.lock index 326ad3e7..84d55e9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,9 +96,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.93" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" dependencies = [ "backtrace", ] @@ -168,9 +168,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cargo-zisk" @@ -198,9 +198,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" dependencies = [ "jobserver", "libc", @@ -248,9 +248,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.21" +version = "4.5.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +checksum = "69371e34337c4c984bbe322360c2547210bf632eb2814bbe78a6e87a2935bd2b" dependencies = [ "clap_builder", "clap_derive", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.21" +version = "4.5.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +checksum = "6e24c1b4099818523236a8ca881d2b45db98dadfb4625cf6608c12069fcbbde1" dependencies = [ "anstream", "anstyle", @@ -496,12 +496,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -682,9 +682,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "heck" @@ -692,12 +692,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - [[package]] name = "hermit-abi" version = "0.4.0" @@ -706,9 +700,9 @@ checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" [[package]] name = "http" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" dependencies = [ "bytes", "fnv", @@ -947,9 +941,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" +checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", "hashbrown", @@ -999,7 +993,7 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" dependencies = [ - "hermit-abi 0.4.0", + "hermit-abi", "libc", "windows-sys 0.52.0", ] @@ -1039,9 +1033,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.12" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a73e9fe3c49d7afb2ace819fa181a287ce54a0983eda4e0eb05c22f82ffe534" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "jobserver" @@ -1054,10 +1048,11 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.72" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" dependencies = [ + "once_cell", "wasm-bindgen", ] @@ -1075,9 +1070,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.164" +version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" +checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" [[package]] name = "libgit2-sys" @@ -1093,9 +1088,9 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", "windows-targets 0.52.6", @@ -1131,9 +1126,9 @@ checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "litemap" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "lock_api" @@ -1183,11 +1178,10 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -1467,7 +1461,7 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "log", "num-bigint", @@ -1485,7 +1479,7 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "bytes", "log", @@ -1542,9 +1536,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" [[package]] name = "powerfmt" @@ -1595,9 +1589,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.89" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -1605,7 +1599,7 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "colored", "env_logger", @@ -1626,11 +1620,12 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "env_logger", "log", "p3-field", + "p3-goldilocks", "pilout", "proofman-macros", "proofman-starks-lib-c", @@ -1644,7 +1639,7 @@ dependencies = [ [[package]] name = "proofman-hints" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "p3-field", "proofman-common", @@ -1654,7 +1649,7 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "proc-macro2", "quote", @@ -1664,7 +1659,7 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "log", ] @@ -1672,7 +1667,7 @@ dependencies = [ [[package]] name = "proofman-util" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "colored", "sysinfo 0.31.4", @@ -1753,7 +1748,7 @@ dependencies = [ "rustc-hash", "rustls", "socket2", - "thiserror 2.0.3", + "thiserror 2.0.4", "tokio", "tracing", ] @@ -1772,7 +1767,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.3", + "thiserror 2.0.4", "tinyvec", "tracing", "web-time", @@ -1989,7 +1984,8 @@ dependencies = [ "proofman-common", "sm-rom", "stark", - "sysinfo 0.32.0", + "sysinfo 0.32.1", + "zisk-pil", ] [[package]] @@ -2000,9 +1996,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustc-hash" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" +checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" [[package]] name = "rustix" @@ -2019,9 +2015,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.17" +version = "0.23.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" +checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" dependencies = [ "once_cell", "ring", @@ -2293,9 +2289,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", @@ -2316,7 +2312,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "log", "p3-field", @@ -2357,9 +2353,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "symbolic-common" -version = "12.12.1" +version = "12.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d4d73159efebfb389d819fd479afb2dbd57dcb3e3f4b7fcfa0e675f5a46c1cb" +checksum = "e5ba5365997a4e375660bed52f5b42766475d5bc8ceb1bb13fea09c469ea0f49" dependencies = [ "debugid", "memmap2", @@ -2369,9 +2365,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.12.1" +version = "12.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a767859f6549c665011970874c3f541838b4835d5aaaa493d3ee383918be9f10" +checksum = "beff338b2788519120f38c59ff4bb15174f52a183e547bac3d6072c2c0aa48aa" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -2380,9 +2376,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.87" +version = "2.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" dependencies = [ "proc-macro2", "quote", @@ -2391,9 +2387,9 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] @@ -2425,9 +2421,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.32.0" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3b5ae3f4f7d64646c46c4cae4e3f01d1c5d255c7406fdd7c7f999a94e488791" +checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" dependencies = [ "core-foundation-sys", "libc", @@ -2467,11 +2463,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.3" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +checksum = "2f49a1853cf82743e3b7950f77e0f4d622ca36cf4317cba00c767838bac8d490" dependencies = [ - "thiserror-impl 2.0.3", + "thiserror-impl 2.0.4", ] [[package]] @@ -2487,9 +2483,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.3" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +checksum = "8381894bb3efe0c4acac3ded651301ceee58a15d47c2e34885ed1908ad667061" dependencies = [ "proc-macro2", "quote", @@ -2498,9 +2494,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", "itoa", @@ -2521,9 +2517,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" dependencies = [ "num-conv", "time-core", @@ -2575,9 +2571,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.41.1" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" +checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" dependencies = [ "backtrace", "bytes", @@ -2634,9 +2630,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -2645,9 +2641,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -2656,9 +2652,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", ] @@ -2666,7 +2662,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" dependencies = [ "proofman-starks-lib-c", ] @@ -2713,9 +2709,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.3" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", @@ -2798,9 +2794,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" dependencies = [ "cfg-if", "once_cell", @@ -2809,9 +2805,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" dependencies = [ "bumpalo", "log", @@ -2824,21 +2820,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.45" +version = "0.4.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" +checksum = "9dfaf8f50e5f293737ee323940c7d8b08a66a95a419223d9f41610ca08b0833d" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2846,9 +2843,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" dependencies = [ "proc-macro2", "quote", @@ -2859,9 +2856,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.95" +version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" [[package]] name = "wasm-streams" @@ -2878,9 +2875,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.72" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +checksum = "a98bc3c33f0fe7e59ad7cd041b89034fa82a7c2d4365ca538dda6cdaf513863c" dependencies = [ "js-sys", "wasm-bindgen", @@ -2898,9 +2895,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" dependencies = [ "rustls-pki-types", ] @@ -3187,9 +3184,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" dependencies = [ "serde", "stable_deref_trait", @@ -3199,9 +3196,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", @@ -3232,18 +3229,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index b97f5cb2..9f456214 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,13 +26,13 @@ opt-level = 3 opt-level = 3 [workspace.dependencies] -proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -# Local development +proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } +proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } +proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } +proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } +pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } +stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } +#Local development # proofman-common = { path = "../pil2-proofman/common" } # proofman-macros = { path = "../pil2-proofman/macros" } # proofman-util = { path = "../pil2-proofman/util" } diff --git a/book/getting_started/quickstart.md b/book/getting_started/quickstart.md index 724a3587..0710b3b5 100644 --- a/book/getting_started/quickstart.md +++ b/book/getting_started/quickstart.md @@ -112,9 +112,8 @@ ziskup ```bash git clone https://github.com/0xPolygonHermez/zisk git clone -b develop https://github.com/0xPolygonHermez/pil2-compiler.git -git clone -b 0.0.10 https://github.com/0xPolygonHermez/pil2-proofman.git -git clone -b 0.0.10 https://github.com/0xPolygonHermez/pil2-proofman-js -git clone -b 0.0.10 https://github.com/0xPolygonHermez/pil2-stark-js +git clone -b 0.0.16 https://github.com/0xPolygonHermez/pil2-proofman.git +git clone -b 0.0.16 https://github.com/0xPolygonHermez/pil2-proofman-js ``` All following commands should be executed in the `zisk` folder. @@ -157,11 +156,10 @@ cargo build --release To generate the proof, the following command needs to be run. ```bash -(cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../hello_world/target/riscv64ima-polygon-ziskos-elf/release/sha_hasher -i ../hello_world/build/input.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -d -v -a) +(cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../hello_world/target/riscv64ima-polygon-ziskos-elf/release/sha_hasher -i ../hello_world/build/input.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -v -a) ``` ### Verify the Proof ```bash -(cd ../pil2-stark-js && npm i) -node ../pil2-stark-js/src/main_verifier.js -v build/provingKey/zisk/final/final.verkey.json -s build/provingKey/zisk/final/final.starkinfo.json -i build/provingKey/zisk/final/final.verifierinfo.json -o proofs/proofs/final_proof.json -b proofs/publics.json +node ../pil2-proofman-js/src/main_verify -k build/provingKey/ -p proofs -t vadcop_final ``` diff --git a/book/getting_started/quickstart_dev.md b/book/getting_started/quickstart_dev.md index dfa921e9..f3965d23 100644 --- a/book/getting_started/quickstart_dev.md +++ b/book/getting_started/quickstart_dev.md @@ -18,9 +18,8 @@ Run the following commands to clone the necessary repositories: ```bash git clone -b develop https://github.com/0xPolygonHermez/pil2-compiler.git git clone -b develop https://github.com/0xPolygonHermez/zisk.git -git clone -b 0.0.10 https://github.com/0xPolygonHermez/pil2-proofman.git -git clone -b 0.0.10 https://github.com/0xPolygonHermez/pil2-stark-js.git -git clone -b 0.0.10 https://github.com/0xPolygonHermez/pil2-proofman-js +git clone -b 0.0.16 https://github.com/0xPolygonHermez/pil2-proofman.git +git clone -b 0.0.16 https://github.com/0xPolygonHermez/pil2-proofman-js ``` ## Compile a Verifiable Rust Program @@ -166,13 +165,13 @@ To generate the aggregated proofs, add `-a` ```bash // Using input_one_segment.bin -(cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input_one_segment.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -d -a -v) +(cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input_one_segment.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -a -v) // Using input_two_segments.bin -(cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input_two_segments.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -d -a -v) +(cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input_two_segments.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -a -v) // Using input.bin -(cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -d -a -v) +(cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli prove --witness-lib ../zisk/target/release/libzisk_witness.so --rom ../zisk/emulator/benches/data/my.elf -i ../zisk/emulator/benches/data/input.bin --proving-key ../zisk/build/provingKey --output-dir ../zisk/proofs -a -v) ``` ### Verify the Proof @@ -184,5 +183,5 @@ node ../pil2-proofman-js/src/main_verify -k ./build/provingKey -p ./proofs If the aggregation proofs are being generated, can be verified with the following command: ```bash -node ../pil2-stark-js/src/main_verifier.js -v build/provingKey/zisk/final/final.verkey.json -s build/provingKey/zisk/final/final.starkinfo.json -i build/provingKey/zisk/final/final.verifierinfo.json -o proofs/proofs/final_proof.json -b proofs/publics.json +node ../pil2-proofman-js/src/main_verify -k ./build/provingKey/ -p ./proofs -t vadcop_final ``` \ No newline at end of file diff --git a/pil/constants.pil b/pil/constants.pil deleted file mode 100644 index b08eaa57..00000000 --- a/pil/constants.pil +++ /dev/null @@ -1,58 +0,0 @@ -const int P2_1 = 2**1; -const int P2_2 = 2**2; -const int P2_3 = 2**3; -const int P2_4 = 2**4; -const int P2_5 = 2**5; -const int P2_6 = 2**6; -const int P2_7 = 2**7; -const int P2_8 = 2**8; -const int P2_9 = 2**9; -const int P2_10 = 2**10; -const int P2_11 = 2**11; -const int P2_12 = 2**12; -const int P2_13 = 2**13; -const int P2_14 = 2**14; -const int P2_15 = 2**15; -const int P2_16 = 2**16; -const int P2_17 = 2**17; -const int P2_18 = 2**18; -const int P2_19 = 2**19; -const int P2_20 = 2**20; -const int P2_21 = 2**21; -const int P2_22 = 2**22; -const int P2_23 = 2**23; -const int P2_24 = 2**24; -const int P2_31 = 2**31; -const int P2_32 = 2**32; -const int P2_63 = 2**63; -const int P2_64 = 2**64; - - -const int MASK_1 = P2_1 - 1; -const int MASK_2 = P2_2 - 1; -const int MASK_3 = P2_3 - 1; -const int MASK_4 = P2_4 - 1; -const int MASK_5 = P2_5 - 1; -const int MASK_6 = P2_6 - 1; -const int MASK_7 = P2_7 - 1; -const int MASK_8 = P2_8 - 1; -const int MASK_9 = P2_9 - 1; -const int MASK_10 = P2_10 - 1; -const int MASK_11 = P2_11 - 1; -const int MASK_12 = P2_12 - 1; -const int MASK_13 = P2_13 - 1; -const int MASK_14 = P2_14 - 1; -const int MASK_15 = P2_15 - 1; -const int MASK_16 = P2_16 - 1; -const int MASK_17 = P2_17 - 1; -const int MASK_18 = P2_18 - 1; -const int MASK_19 = P2_19 - 1; -const int MASK_20 = P2_20 - 1; -const int MASK_21 = P2_21 - 1; -const int MASK_22 = P2_22 - 1; -const int MASK_23 = P2_23 - 1; -const int MASK_24 = P2_24 - 1; -const int MASK_31 = P2_31 - 1; -const int MASK_32 = P2_32 - 1; -const int MASK_63 = P2_63 - 1; -const int MASK_64 = P2_64 - 1; \ No newline at end of file diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 32cfe09f..da3c9ff1 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -8,7 +8,7 @@ trace!(MainRow, MainTrace { }); trace!(RomRow, RomTrace { - line: F, a_offset_imm0: F, a_imm1: F, b_offset_imm0: F, b_imm1: F, ind_width: F, op: F, store_offset: F, jmp_offset1: F, jmp_offset2: F, flags: F, multiplicity: F, + multiplicity: F, }); trace!(ArithRow, ArithTrace { @@ -24,7 +24,7 @@ trace!(ArithRangeTableRow, ArithRangeTableTrace { }); trace!(BinaryRow, BinaryTrace { - m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F, + m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, cout: F, result_is_a: F, use_last_carry_mode32: F, use_last_carry_mode64: F, m_op_or_ext: F, free_in_a_or_c: [F; 4], free_in_b_or_zero: [F; 4], multiplicity: F, main_step: F, }); trace!(BinaryTableRow, BinaryTableTrace { @@ -42,3 +42,7 @@ trace!(BinaryExtensionTableRow, BinaryExtensionTableTrace { trace!(SpecifiedRangesRow, SpecifiedRangesTrace { mul: [F; 1], }); + +trace!(RomRomRow, RomRomTrace { + line: F, a_offset_imm0: F, a_imm1: F, b_offset_imm0: F, b_imm1: F, ind_width: F, op: F, store_offset: F, jmp_offset1: F, jmp_offset2: F, flags: F, +}); diff --git a/pil/zisk.pil b/pil/zisk.pil index 0e97aeb6..6dc5052c 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -1,5 +1,3 @@ - -require "constants.pil" require "rom/pil/rom.pil" require "main/pil/main.pil" require "binary/pil/binary.pil" diff --git a/rom-merkle/Cargo.toml b/rom-merkle/Cargo.toml index 4558eda3..7fd14009 100644 --- a/rom-merkle/Cargo.toml +++ b/rom-merkle/Cargo.toml @@ -8,6 +8,8 @@ sm-rom = { path = "../state-machines/rom" } log = { workspace = true } stark = { workspace = true } proofman-common = { workspace = true } +zisk-pil = { path="../pil" } + p3-goldilocks = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } p3-field = { workspace = true } diff --git a/rom-merkle/src/main.rs b/rom-merkle/src/main.rs index fa856247..7729582d 100644 --- a/rom-merkle/src/main.rs +++ b/rom-merkle/src/main.rs @@ -1,30 +1,33 @@ use clap::{Arg, Command}; use colored::Colorize; use p3_goldilocks::Goldilocks; -use proofman_common::{GlobalInfo, ProofType, SetupCtx}; +use proofman_common::{get_custom_commit_trace, GlobalInfo, ProofType, SetupCtx}; use sm_rom::RomSM; use stark::StarkBufferAllocator; use std::{path::Path, sync::Arc}; use sysinfo::System; +use zisk_pil::{ROM_AIR_IDS, ZISK_AIRGROUP_ID}; fn main() { let matches = Command::new("ROM Handler") .version("1.0") .about("Compute the Merkle Root of a ROM file") - .arg(Arg::new("rom").value_name("FILE").help("The ROM file path").required(true).index(1)) + .arg( + Arg::new("rom").long("rom").value_name("FILE").help("The ROM file path").required(true), + ) .arg( Arg::new("proving_key") + .long("proving-key") .value_name("FILE") .help("The proving key folder path") - .required(true) - .index(2), + .required(true), ) .arg( - Arg::new("global_info") + Arg::new("rom_buffer") + .long("rom-buffer") .value_name("FILE") - .help("The global info file path") - .required(true) - .index(3), + .help("The rom buffer path") + .required(true), ) .get_matches(); @@ -34,9 +37,8 @@ fn main() { let proving_key_path_str = matches.get_one::("proving_key").expect("Proving key path is required"); let proving_key_path = Path::new(proving_key_path_str); - let global_info_path_str = - matches.get_one::("global_info").expect("Global info path is required"); - let global_info_path = Path::new(global_info_path_str); + let rom_buffer_str = + matches.get_one::("rom_buffer").expect("Buffer file path is required"); env_logger::builder() .format_timestamp(None) @@ -78,23 +80,24 @@ fn main() { std::process::exit(1); } - // If all checks pass, continue with the program - println!("ROM Path is valid: {}", rom_path.display()); - let buffer_allocator: Arc = Arc::new(StarkBufferAllocator::new(proving_key_path.to_path_buf())); - let global_info = GlobalInfo::new(global_info_path); + let global_info = GlobalInfo::new(proving_key_path); let sctx = Arc::new(SetupCtx::new(&global_info, &ProofType::Basic)); - if let Err(e) = - RomSM::::compute_trace(rom_path.to_path_buf(), buffer_allocator, &sctx) - { - log::error!("Error: {}", e); - std::process::exit(1); - } - - // Compute LDE and Merkelize and get the root of the rom - // TODO: Implement the logic to compute the trace + let setup = sctx.get_setup(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]); - log::info!("ROM proof successful"); + match RomSM::::compute_trace_rom_buffer( + rom_path.to_path_buf(), + buffer_allocator, + &sctx, + ) { + Ok((commit_id, buffer_rom)) => { + get_custom_commit_trace(commit_id, 0, setup, buffer_rom, rom_buffer_str.as_str()); + } + Err(e) => { + log::error!("Error: {}", e); + std::process::exit(1); + } + } } diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index 2e2a4d5e..d14d1b7b 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -55,13 +55,8 @@ impl ArithSM { self.arith_full_sm.unregister_predecessor(); } } - pub fn prove_instance( - &self, - operations: Vec, - prover_buffer: &mut [F], - offset: u64, - ) { - self.arith_full_sm.prove_instance(operations, prover_buffer, offset); + pub fn prove_instance(&self, operations: Vec, prover_buffer: &mut [F]) { + self.arith_full_sm.prove_instance(operations, prover_buffer); } } diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 082fb452..954a4030 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -60,12 +60,7 @@ impl ArithFullSM { self.arith_range_table_sm.unregister_predecessor(); } } - pub fn prove_instance( - &self, - input: Vec, - prover_buffer: &mut [F], - offset: u64, - ) { + pub fn prove_instance(&self, input: Vec, prover_buffer: &mut [F]) { let mut range_table_inputs = ArithRangeTableInputs::new(); let mut table_inputs = ArithTableInputs::new(); @@ -82,12 +77,10 @@ impl ArithFullSM { ); assert!(input.len() <= num_rows); - let mut traces = - ArithTrace::::map_buffer(prover_buffer, num_rows, offset as usize).unwrap(); + let mut traces = ArithTrace::::map_buffer(prover_buffer, num_rows, 0).unwrap(); let mut aop = ArithOperation::new(); for (irow, input) in input.iter().enumerate() { - println!("#{} ARITH op:0x{:X} a:0x{:X} b:0x{:X}", irow, input.opcode, input.a, input.b); aop.calculate(input.opcode, input.a, input.b); let mut t: ArithRow = Default::default(); for i in [0, 2] { diff --git a/state-machines/arith/src/arith_range_table.rs b/state-machines/arith/src/arith_range_table.rs index 4b3dcf1f..5f3ac89f 100644 --- a/state-machines/arith/src/arith_range_table.rs +++ b/state-machines/arith/src/arith_range_table.rs @@ -1,5 +1,5 @@ use std::sync::{ - atomic::{AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, Ordering}, Arc, Mutex, }; @@ -9,8 +9,7 @@ use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx}; use rayon::prelude::*; -use sm_common::create_prover_buffer; -use zisk_pil::{ARITH_RANGE_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{ArithRangeTableTrace, ARITH_RANGE_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct ArithRangeTableSM { wcm: Arc>, @@ -21,6 +20,7 @@ pub struct ArithRangeTableSM { // Inputs num_rows: usize, multiplicity: Mutex>, + used: AtomicBool, } impl ArithRangeTableSM { @@ -34,6 +34,7 @@ impl ArithRangeTableSM { registered_predecessors: AtomicU32::new(0), num_rows: air.num_rows(), multiplicity: Mutex::new(vec![0; air.num_rows()]), + used: AtomicBool::new(false), }; let arith_range_table_sm = Arc::new(arith_range_table_sm); @@ -47,7 +48,9 @@ impl ArithRangeTableSM { } pub fn unregister_predecessor(&self) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 && + self.used.load(Ordering::SeqCst) + { self.create_air_instance(); } } @@ -58,6 +61,7 @@ impl ArithRangeTableSM { for (row, value) in inputs { _multiplicity[row] += value; } + self.used.store(true, Ordering::Relaxed); } pub fn create_air_instance(&self) { let ectx = self.wcm.get_ectx(); @@ -73,20 +77,15 @@ impl ArithRangeTableSM { dctx.distribute_multiplicity(&mut multiplicity_, owner); if is_myne { - // Create the prover buffer - let (mut prover_buffer, offset) = create_prover_buffer( - &self.wcm.get_ectx(), - &self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - ARITH_RANGE_TABLE_AIR_IDS[0], - ); - prover_buffer[offset as usize..offset as usize + self.num_rows] + let trace: ArithRangeTableTrace<'_, _> = ArithRangeTableTrace::new(self.num_rows); + let mut prover_buffer = trace.buffer.unwrap(); + prover_buffer[0..self.num_rows] .par_iter_mut() .enumerate() .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); info!( - "{}: ··· Creating Binary basic table instance [{} rows filled 100%]", + "{}: ··· Creating Arith range basic table instance [{} rows filled 100%]", Self::MY_NAME, self.num_rows, ); @@ -111,8 +110,8 @@ impl WitnessComponent for ArithRangeTableSM { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + _ectx: Arc, + _sctx: Arc, ) { } } diff --git a/state-machines/arith/src/arith_table.rs b/state-machines/arith/src/arith_table.rs index 6805f407..dc535754 100644 --- a/state-machines/arith/src/arith_table.rs +++ b/state-machines/arith/src/arith_table.rs @@ -1,5 +1,5 @@ use std::sync::{ - atomic::{AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, Ordering}, Arc, Mutex, }; @@ -9,8 +9,7 @@ use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx}; use rayon::prelude::*; -use sm_common::create_prover_buffer; -use zisk_pil::{ARITH_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{ArithTableTrace, ARITH_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct ArithTableSM { wcm: Arc>, @@ -21,6 +20,7 @@ pub struct ArithTableSM { // Inputs num_rows: usize, multiplicity: Mutex>, + used: AtomicBool, } impl ArithTableSM { @@ -34,6 +34,7 @@ impl ArithTableSM { registered_predecessors: AtomicU32::new(0), num_rows: air.num_rows(), multiplicity: Mutex::new(vec![0; air.num_rows()]), + used: AtomicBool::new(false), }; let arith_table_sm = Arc::new(_arith_table_sm); @@ -47,7 +48,9 @@ impl ArithTableSM { } pub fn unregister_predecessor(&self) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 && + self.used.load(Ordering::SeqCst) + { self.create_air_instance(); } } @@ -60,6 +63,7 @@ impl ArithTableSM { info!("{}: ··· Processing row {} with value {}", Self::MY_NAME, row, value); _multiplicity[row] += value; } + self.used.store(true, Ordering::Relaxed); } pub fn create_air_instance(&self) { let ectx = self.wcm.get_ectx(); @@ -75,20 +79,15 @@ impl ArithTableSM { dctx.distribute_multiplicity(&mut multiplicity_, owner); if is_myne { - // Create the prover buffer - let (mut prover_buffer, offset) = create_prover_buffer( - &self.wcm.get_ectx(), - &self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - ARITH_TABLE_AIR_IDS[0], - ); - prover_buffer[offset as usize..offset as usize + self.num_rows] + let trace: ArithTableTrace<'_, _> = ArithTableTrace::new(self.num_rows); + let mut prover_buffer = trace.buffer.unwrap(); + prover_buffer[0..self.num_rows] .par_iter_mut() .enumerate() .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); info!( - "{}: ··· Creating Binary basic table instance [{} rows filled 100%]", + "{}: ··· Creating Arith basic table instance [{} rows filled 100%]", Self::MY_NAME, self.num_rows, ); @@ -113,8 +112,8 @@ impl WitnessComponent for ArithTableSM { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + _ectx: Arc, + _sctx: Arc, ) { } } diff --git a/state-machines/binary/pil/binary.pil b/state-machines/binary/pil/binary.pil index b6c48f73..de055894 100644 --- a/state-machines/binary/pil/binary.pil +++ b/state-machines/binary/pil/binary.pil @@ -59,12 +59,11 @@ require "std_lookup.pil" 0 XOR 0x22 0x22 */ -const int BINARY_ID = 20; - -airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) { +airtemplate Binary(const int N = 2**21, const int operation_bus_id) { // Default values const int bits = 64; const int bytes = bits / 8; + const int half_bytes = bytes / 2; // Main values const int input_chunks = 2; @@ -80,51 +79,70 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) // Secondary columns col witness use_last_carry; // 1 if the operation uses the last carry as its result - col witness op_is_min_max; // 1 if op ∈ {MINU,MIN,MAXU,MAX} + col witness op_is_min_max; // 1 if the operation is any of the MIN/MAX operations - const expr cout32 = carry[bytes/2-1]; + const expr mode64 = 1 - mode32; + const expr cout32 = carry[half_bytes-1]; const expr cout64 = carry[bytes-1]; - expr cout = (1-mode32) * (cout64 - cout32) + cout32; use_last_carry * (1 - use_last_carry) === 0; op_is_min_max * (1 - op_is_min_max) === 0; cout32*(1 - cout32) === 0; cout64*(1 - cout64) === 0; - // Constraints to check the correctness of each binary operation + // Auxiliary columns (primarily used to optimize lookups, but can be substituted with expressions) + col witness cout; + col witness result_is_a; + col witness use_last_carry_mode32; + col witness use_last_carry_mode64; + cout === mode64 * (cout64 - cout32) + cout32; + result_is_a === op_is_min_max * cout; + use_last_carry_mode32 === mode32 * use_last_carry; + use_last_carry_mode64 === mode64 * use_last_carry; + /* - opid last a b c cin cout - ─────────────────────────────────────────────────────────────── - m_op 0 a0 b0 c0 0 carry0 - m_op 0 a1 b1 c1 carry0 carry1 - m_op 0 a2 b2 c2 carry1 carry2 - m_op 0 a3 b3 c3 carry2 carry3 + 2*use_last_carry - m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 - m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 - m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 - m_op|EXT_32 1 a7|c3 b7|0 c7 carry6 carry7 + 2*use_last_carry + Constraints to check the correctness of each binary operation + opid last a b c cin cout + flags + ───────────────────────────────────────────────────────────────------------------------------------------------- + m_op 0 a0 b0 c0 0 carry0 + 2*op_is_min_max + 4*result_is_a + m_op 0 a1 b1 c1 carry0 carry1 + 2*op_is_min_max + 4*result_is_a + m_op 0 a2 b2 c2 carry1 carry2 + 2*op_is_min_max + 4*result_is_a + m_op 0|1 a3 b3 c3 carry2 carry3 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32 + m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0|1 a7|c3 b7|0 c7 carry6 carry7 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64 + ───────────────────────────────────────────────────────────────------------------------------------------------- + Perform, at the byte level, lookups against the binary table on inputs: + [last, m_op, a, b, cin, c, cout + flags] + where last indicates whether the byte is the last one in the operation */ - // Perform, at the byte level, lookups against the binary table on inputs: - // [last, m_op, a, b, cin, c, cout + flags] - // where last indicates whether the byte is the last one in the operation - - lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*op_is_min_max*cout]); + lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*result_is_a]); - expr _m_op = (1-mode32) * (m_op - EXT_32_OP) + EXT_32_OP; + // More auxiliary columns + col witness m_op_or_ext; + col witness free_in_a_or_c[half_bytes]; + col witness free_in_b_or_zero[half_bytes]; + m_op_or_ext === mode64 * (m_op - EXT_32_OP) + EXT_32_OP; + int j = 0; for (int i = 1; i < bytes; i++) { - expr _free_in_a = (1-mode32) * (free_in_a[i] - free_in_c[bytes/2-1]) + free_in_c[bytes/2-1]; - expr _free_in_b = (1-mode32) * free_in_b[i]; - - if (i < bytes/2 - 1) { - lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]); - } else if (i == bytes/2 - 1) { - lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*mode32]); - } else if (i < bytes - 1) { - lookup_assumes(BINARY_TABLE_ID, [0, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]); - } else { - lookup_assumes(BINARY_TABLE_ID, [1-mode32, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*(1-mode32)]); - } + if (i >= half_bytes) { + free_in_a_or_c[j] === mode64 * (free_in_a[i] - free_in_c[half_bytes-1]) + free_in_c[half_bytes-1]; + free_in_b_or_zero[j] === mode64 * free_in_b[i]; + } + + if (i < half_bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]); + } else if (i == half_bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32]); + } else if (i < bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [0, m_op_or_ext, free_in_a_or_c[j], free_in_b_or_zero[j], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]); + j++; + } else { + lookup_assumes(BINARY_TABLE_ID, [mode64, m_op_or_ext, free_in_a_or_c[j], free_in_b_or_zero[j], carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64]); + j++; + } } // Constraints to make sure that this component is called from the main component @@ -164,5 +182,5 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) col witness multiplicity; col witness main_step; - lookup_proves(OPERATION_BUS_ID, [main_step, op, ...a, ...b, ...c, (1-op_is_min_max)*cout], multiplicity); + lookup_proves(OPERATION_BUS_ID, [main_step, op, ...a, ...b, ...c, cout - result_is_a], multiplicity); } \ No newline at end of file diff --git a/state-machines/binary/pil/binary_extension.pil b/state-machines/binary/pil/binary_extension.pil index c26594d8..74a411b4 100644 --- a/state-machines/binary/pil/binary_extension.pil +++ b/state-machines/binary/pil/binary_extension.pil @@ -1,4 +1,3 @@ -require "std_permutation.pil" require "std_lookup.pil" require "std_range_check.pil" @@ -66,9 +65,7 @@ x in2[x] out[x][0] out[x][1] Result: 0xFFFF8abc 0xFFFFFFFF */ -const int BINARY_EXTENSION_ID = 21; - -airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BINARY_EXTENSION_ID) { +airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id) { const int bits = 64; const int bytes = bits / 8; diff --git a/state-machines/binary/pil/binary_extension_table.pil b/state-machines/binary/pil/binary_extension_table.pil index 871debe4..d346ff97 100644 --- a/state-machines/binary/pil/binary_extension_table.pil +++ b/state-machines/binary/pil/binary_extension_table.pil @@ -1,5 +1,5 @@ +require "std_constants.pil" require "std_lookup.pil" -require "constants.pil" // Operations Table: // Running Total diff --git a/state-machines/binary/pil/binary_table.pil b/state-machines/binary/pil/binary_table.pil index 2e17c36b..316ef9a0 100644 --- a/state-machines/binary/pil/binary_table.pil +++ b/state-machines/binary/pil/binary_table.pil @@ -1,4 +1,5 @@ -require "constants.pil"; +require "std_constants.pil"; +require "std_lookup.pil" // PIL Binary Operations Table used by Binary // Running Total @@ -65,7 +66,7 @@ airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) { 0x02:P2_18, 0x03:P2_18, // ADD,SUB 0x06:P2_17, 0x07:P2_17, // LEU,LE 0x20:P2_17, 0x21:P2_17, 0x22:P2_17, // AND,OR,XOR - 0x23:P2_11]...; // EXT_32 + EXT_32_OP:P2_11]...; // EXT_32 // NOTE: MINU/MINU_W, MIN/MIN_W, MAXU/MAXU_W, MAX/MAX_W has double size because // the result_is_a is 0 in the first half and 1 in the second half. diff --git a/state-machines/binary/src/binary.rs b/state-machines/binary/src/binary.rs index 7dcec2c8..01311cd6 100644 --- a/state-machines/binary/src/binary.rs +++ b/state-machines/binary/src/binary.rs @@ -96,12 +96,11 @@ impl BinarySM { operations: Vec, is_extension: bool, prover_buffer: &mut [F], - offset: u64, ) { if !is_extension { - self.binary_basic_sm.prove_instance(operations, prover_buffer, offset); + self.binary_basic_sm.prove_instance(operations, prover_buffer); } else { - self.binary_extension_sm.prove_instance(operations, prover_buffer, offset); + self.binary_extension_sm.prove_instance(operations, prover_buffer); } } } diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index 765ed11f..e0c8f206 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -9,13 +9,16 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; use rayon::Scope; -use sm_common::{create_prover_buffer, OpResult, Provable}; +use sm_common::{OpResult, Provable}; use std::cmp::Ordering as CmpOrdering; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::*; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; +const BYTES: usize = 8; +const HALF_BYTES: usize = BYTES / 2; + pub struct BinaryBasicSM { wcm: Arc>, @@ -158,6 +161,7 @@ impl BinaryBasicSM { let opcode = ZiskOp::try_from_code(operation.opcode).expect("Invalid ZiskOp opcode"); let mode32 = Self::opcode_is_32_bits(opcode); row.mode32 = F::from_bool(mode32); + let mode64 = F::from_bool(!mode32); // Set c_filtered let c_filtered = if mode32 { c & 0xFFFFFFFF } else { c }; @@ -667,6 +671,33 @@ impl BinaryBasicSM { _ => panic!("BinaryBasicSM::process_slice() found invalid opcode={}", operation.opcode), } + // Set cout + let cout32 = row.carry[HALF_BYTES - 1]; + let cout64 = row.carry[BYTES - 1]; + row.cout = mode64 * (cout64 - cout32) + cout32; + + // Set result_is_a + row.result_is_a = row.op_is_min_max * row.cout; + + // Set use_last_carry_mode32 and use_last_carry_mode64 + row.use_last_carry_mode32 = F::from_bool(mode32) * row.use_last_carry; + row.use_last_carry_mode64 = mode64 * row.use_last_carry; + + // Set micro opcode + row.m_op = F::from_canonical_u8(binary_basic_table_op as u8); + + // Set m_op_or_ext + let ext_32_op = F::from_canonical_u8(BinaryBasicTableOp::Ext32 as u8); + row.m_op_or_ext = mode64 * (row.m_op - ext_32_op) + ext_32_op; + + // Set free_in_a_or_c and free_in_b_or_zero + for i in 0..HALF_BYTES { + row.free_in_a_or_c[i] = mode64 * + (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) + + row.free_in_c[HALF_BYTES - 1]; + row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES]; + } + if row.use_last_carry == F::one() { // Set first and last elements row.free_in_c[7] = row.free_in_c[0]; @@ -676,26 +707,12 @@ impl BinaryBasicSM { // TODO: Find duplicates of this trace and reuse them by increasing their multiplicity. row.multiplicity = F::one(); - // Set micro opcode - row.m_op = F::from_canonical_u8(binary_basic_table_op as u8); - // Return row } - pub fn prove_instance( - &self, - operations: Vec, - prover_buffer: &mut [F], - offset: u64, - ) { - Self::prove_internal( - &self.wcm, - &self.binary_basic_table_sm, - operations, - prover_buffer, - offset, - ); + pub fn prove_instance(&self, operations: Vec, prover_buffer: &mut [F]) { + Self::prove_internal(&self.wcm, &self.binary_basic_table_sm, operations, prover_buffer); } fn prove_internal( @@ -703,7 +720,6 @@ impl BinaryBasicSM { binary_basic_table_sm: &BinaryBasicTableSM, operations: Vec, prover_buffer: &mut [F], - offset: u64, ) { timer_start_trace!(BINARY_TRACE); let pctx = wcm.get_pctx(); @@ -721,7 +737,7 @@ impl BinaryBasicSM { let mut multiplicity_table = vec![0u64; air_binary_table.num_rows()]; let mut trace_buffer = - BinaryTrace::::map_buffer(prover_buffer, air.num_rows(), offset as usize).unwrap(); + BinaryTrace::::map_buffer(prover_buffer, air.num_rows(), 0).unwrap(); for (i, operation) in operations.iter().enumerate() { let row = Self::process_slice(operation, &mut multiplicity_table); @@ -732,6 +748,7 @@ impl BinaryBasicSM { timer_start_trace!(BINARY_PADDING); let padding_row = BinaryRow:: { m_op: F::from_canonical_u8(0x20), + m_op_or_ext: F::from_canonical_u8(0x20), multiplicity: F::zero(), main_step: F::zero(), /* TODO: remove, since main_step is just for * debugging */ @@ -789,19 +806,14 @@ impl Provable for BinaryBasicSM { let sctx = self.wcm.get_sctx().clone(); - let (mut prover_buffer, offset) = create_prover_buffer( - &wcm.get_ectx(), - &wcm.get_sctx(), - ZISK_AIRGROUP_ID, - BINARY_AIR_IDS[0], - ); + let trace: BinaryTrace<'_, _> = BinaryTrace::new(air.num_rows()); + let mut prover_buffer = trace.buffer.unwrap(); Self::prove_internal( &wcm, &binary_basic_table_sm, drained_inputs, &mut prover_buffer, - offset, ); let air_instance = AirInstance::new( diff --git a/state-machines/binary/src/binary_basic_table.rs b/state-machines/binary/src/binary_basic_table.rs index 0adeb18e..374910f3 100644 --- a/state-machines/binary/src/binary_basic_table.rs +++ b/state-machines/binary/src/binary_basic_table.rs @@ -8,9 +8,8 @@ use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use rayon::prelude::*; -use sm_common::create_prover_buffer; use zisk_core::{zisk_ops::ZiskOp, P2_16, P2_17, P2_18, P2_19, P2_8, P2_9}; -use zisk_pil::{BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{BinaryTableTrace, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] @@ -247,13 +246,10 @@ impl BinaryBasicTableSM { if is_myne { // Create the prover buffer - let (mut prover_buffer, offset) = create_prover_buffer( - &self.wcm.get_ectx(), - &self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - BINARY_TABLE_AIR_IDS[0], - ); - prover_buffer[offset as usize..offset as usize + self.num_rows] + let trace: BinaryTableTrace<'_, _> = BinaryTableTrace::new(self.num_rows); + let mut prover_buffer = trace.buffer.unwrap(); + + prover_buffer[0..self.num_rows] .par_iter_mut() .enumerate() .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index eee0c226..5077107c 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -15,7 +15,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use rayon::Scope; -use sm_common::{create_prover_buffer, OpResult, Provable}; +use sm_common::{OpResult, Provable}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::*; @@ -362,19 +362,13 @@ impl BinaryExtensionSM { row } - pub fn prove_instance( - &self, - operations: Vec, - prover_buffer: &mut [F], - offset: u64, - ) { + pub fn prove_instance(&self, operations: Vec, prover_buffer: &mut [F]) { Self::prove_internal( &self.wcm, &self.binary_extension_table_sm, &self.std, operations, prover_buffer, - offset, ); } @@ -384,7 +378,6 @@ impl BinaryExtensionSM { std: &Std, operations: Vec, prover_buffer: &mut [F], - offset: u64, ) { timer_start_debug!(BINARY_EXTENSION_TRACE); let pctx = wcm.get_pctx(); @@ -405,8 +398,7 @@ impl BinaryExtensionSM { let mut multiplicity_table = vec![0u64; air_binary_extension_table.num_rows()]; let mut range_check: HashMap = HashMap::new(); let mut trace_buffer = - BinaryExtensionTrace::::map_buffer(prover_buffer, air.num_rows(), offset as usize) - .unwrap(); + BinaryExtensionTrace::::map_buffer(prover_buffer, air.num_rows(), 0).unwrap(); for (i, operation) in operations.iter().enumerate() { let row = Self::process_slice(operation, &mut multiplicity_table, &mut range_check); @@ -479,12 +471,8 @@ impl Provable for BinaryExtensio let sctx = self.wcm.get_sctx().clone(); - let (mut prover_buffer, offset) = create_prover_buffer( - &wcm.get_ectx(), - &wcm.get_sctx(), - ZISK_AIRGROUP_ID, - BINARY_EXTENSION_AIR_IDS[0], - ); + let trace: BinaryExtensionTrace<'_, _> = BinaryExtensionTrace::new(air.num_rows()); + let mut prover_buffer = trace.buffer.unwrap(); Self::prove_internal( &wcm, @@ -492,7 +480,6 @@ impl Provable for BinaryExtensio &std, drained_inputs, &mut prover_buffer, - offset, ); let air_instance = AirInstance::new( diff --git a/state-machines/binary/src/binary_extension_table.rs b/state-machines/binary/src/binary_extension_table.rs index 36c64666..6916b78a 100644 --- a/state-machines/binary/src/binary_extension_table.rs +++ b/state-machines/binary/src/binary_extension_table.rs @@ -8,9 +8,8 @@ use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use rayon::prelude::*; -use sm_common::create_prover_buffer; use zisk_core::{zisk_ops::ZiskOp, P2_11, P2_19, P2_8}; -use zisk_pil::{BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{BinaryExtensionTableTrace, BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] @@ -136,15 +135,11 @@ impl BinaryExtensionTableSM { dctx.distribute_multiplicity(&mut multiplicity_, owner); if is_myne { - // Create the prover buffer - let (mut prover_buffer, offset) = create_prover_buffer( - &self.wcm.get_ectx(), - &self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - BINARY_EXTENSION_TABLE_AIR_IDS[0], - ); + let trace: BinaryExtensionTableTrace<'_, _> = + BinaryExtensionTableTrace::new(self.num_rows); + let mut prover_buffer = trace.buffer.unwrap(); - prover_buffer[offset as usize..offset as usize + self.num_rows] + prover_buffer[0..self.num_rows] .par_iter_mut() .enumerate() .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); diff --git a/state-machines/common/src/lib.rs b/state-machines/common/src/lib.rs index bb6f10ee..9749dc32 100644 --- a/state-machines/common/src/lib.rs +++ b/state-machines/common/src/lib.rs @@ -7,27 +7,7 @@ mod worker; pub use field::*; pub use operations::*; -use proofman_common::{ExecutionCtx, SetupCtx}; -use proofman_util::create_buffer_fast; pub use provable::*; pub use session::*; pub use temp::*; pub use worker::*; - -pub fn create_prover_buffer( - ectx: &ExecutionCtx, - sctx: &SetupCtx, - airgroup_id: usize, - air_id: usize, -) -> (Vec, u64) { - // Compute buffer size using the BufferAllocator - let (buffer_size, offsets) = ectx - .buffer_allocator - .as_ref() - .get_buffer_info(sctx, airgroup_id, air_id) - .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - - let buffer = create_buffer_fast(buffer_size as usize); - - (buffer, offsets[0]) -} diff --git a/state-machines/freq-ops/src/freq_ops.rs b/state-machines/freq-ops/src/freq_ops.rs index 56bc4190..26e10338 100644 --- a/state-machines/freq-ops/src/freq_ops.rs +++ b/state-machines/freq-ops/src/freq_ops.rs @@ -48,8 +48,8 @@ impl WitnessComponent for FreqOpsSM { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + _ectx: Arc, + _sctx: Arc, ) { } } diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 027bee94..b787f3c2 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -1,4 +1,5 @@ require "std_lookup.pil" +require "std_permutation.pil" require "std_common.pil" const int BOOT_ADDR = 0x1000; @@ -274,6 +275,6 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope lookup_assumes(ROM_BUS_ID, [pc, a_offset_imm0, a_imm1, b_offset_imm0, b_imm1, ind_width, op, store_offset, jmp_offset1, jmp_offset2, rom_flags], sel: 1 - SEGMENT_L1); - direct_update(MAIN_CONTINUATION_ID, cols: [0, 0, 4096, 0, 0], bus_type: PIOP_BUS_SUM, proves: 1); - direct_update(MAIN_CONTINUATION_ID, cols: [0, 1, 0x10000000, 0, 0], bus_type: PIOP_BUS_SUM, proves: 0); + direct_global_update(MAIN_CONTINUATION_ID, cols: [0, 0, 4096, 0, 0], bus_type: PIOP_BUS_SUM, proves: 1); + direct_global_update(MAIN_CONTINUATION_ID, cols: [0, 1, 0x10000000, 0, 0], bus_type: PIOP_BUS_SUM, proves: 0); } \ No newline at end of file diff --git a/state-machines/main/src/instance_extension.rs b/state-machines/main/src/instance_extension.rs index 42aea872..551dda77 100644 --- a/state-machines/main/src/instance_extension.rs +++ b/state-machines/main/src/instance_extension.rs @@ -3,8 +3,6 @@ use zisk_core::ZiskOperationType; use ziskemu::EmuTraceStart; pub struct InstanceExtensionCtx { - pub prover_buffer: Vec, - pub offset: u64, pub op_type: ZiskOperationType, pub emu_trace_start: EmuTraceStart, pub segment_id: Option, @@ -14,22 +12,12 @@ pub struct InstanceExtensionCtx { impl InstanceExtensionCtx { pub fn new( - prover_buffer: Vec, - offset: u64, op_type: ZiskOperationType, emu_trace_start: EmuTraceStart, segment_id: Option, instance_global_idx: usize, air_instance: Option>, ) -> Self { - Self { - prover_buffer, - offset, - op_type, - emu_trace_start, - instance_global_idx, - segment_id, - air_instance, - } + Self { op_type, emu_trace_start, instance_global_idx, segment_id, air_instance } } } diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index 30240f83..a7bcef56 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -14,8 +14,8 @@ use proofman::WitnessComponent; use sm_arith::ArithSM; use sm_mem::MemSM; use zisk_pil::{ - MainRow, MainTrace, ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, - ZISK_AIRGROUP_ID, + ArithTrace, BinaryExtensionTrace, BinaryTrace, MainRow, MainTrace, ARITH_AIR_IDS, + BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, }; use ziskemu::{Emu, EmuTrace, ZiskEmulator}; @@ -80,7 +80,6 @@ impl MainSM { let segment_id = iectx.segment_id.unwrap(); let segment_trace = &vec_traces[segment_id]; - let offset = iectx.offset; let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]); let filled = segment_trace.steps.len() + 1; info!( @@ -133,8 +132,11 @@ impl MainSM { let mut emu = Emu::from_emu_trace_start(zisk_rom, &segment_trace.start_state); - let rng = offset as usize..(offset as usize + MainRow::::ROW_SIZE); - iectx.prover_buffer[rng].copy_from_slice(row0.as_slice()); + let trace = MainTrace::new(air.num_rows()); + let mut prover_buffer = trace.buffer.unwrap(); + + let rng = 0..MainRow::::ROW_SIZE; + prover_buffer[rng].copy_from_slice(row0.as_slice()); // Set Rows 1 to N of the current segment (N = maximum number of air rows) let total_rows = segment_trace.steps.len(); @@ -164,13 +166,21 @@ impl MainSM { //copy the chunk to the prover buffer let partial_buffer = partial_trace.buffer.as_ref().unwrap(); - let buffer_offset_slice = offset as usize + (slice + 1) * MainRow::::ROW_SIZE; + let buffer_offset_slice = (slice + 1) * MainRow::::ROW_SIZE; + + let slice_rows = if slice + SLICE_ROWS >= + pctx.pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows() + { + partial_buffer.len() - MainRow::::ROW_SIZE + } else { + partial_buffer.len() + }; - let rng = buffer_offset_slice..buffer_offset_slice + partial_buffer.len(); - iectx.prover_buffer[rng].copy_from_slice(partial_buffer); + let rng = buffer_offset_slice..buffer_offset_slice + slice_rows; + prover_buffer[rng].copy_from_slice(&partial_buffer[..slice_rows]); } - let buffer = std::mem::take(&mut iectx.prover_buffer); + let buffer = std::mem::take(&mut prover_buffer); let sctx = self.wcm.get_sctx(); let mut air_instance = AirInstance::new( sctx.clone(), @@ -183,9 +193,8 @@ impl MainSM { let main_last_segment = F::from_bool(segment_id == vec_traces.len() - 1); let main_segment = F::from_canonical_usize(segment_id); - air_instance.set_airvalue(&sctx, "Main.main_last_segment", main_last_segment); - air_instance.set_airvalue(&sctx, "Main.main_segment", main_segment); - + air_instance.set_airvalue("Main.main_last_segment", None, main_last_segment); + air_instance.set_airvalue("Main.main_segment", None, main_segment); iectx.air_instance = Some(air_instance); } @@ -210,11 +219,14 @@ impl MainSM { timer_start_debug!(PROVE_ARITH); - self.arith_sm.prove_instance(inputs, &mut iectx.prover_buffer, iectx.offset); + let trace = ArithTrace::new(air.num_rows()); + let mut prover_buffer = trace.buffer.unwrap(); + + self.arith_sm.prove_instance(inputs, &mut prover_buffer); timer_stop_and_log_debug!(PROVE_ARITH); timer_start_debug!(CREATE_AIR_INSTANCE); - let buffer = std::mem::take(&mut iectx.prover_buffer); + let buffer = std::mem::take(&mut prover_buffer); iectx.air_instance = Some(AirInstance::new( self.wcm.get_sctx(), ZISK_AIRGROUP_ID, @@ -245,11 +257,14 @@ impl MainSM { timer_stop_and_log_debug!(PROCESS_BINARY); timer_start_debug!(PROVE_BINARY); - self.binary_sm.prove_instance(inputs, false, &mut iectx.prover_buffer, iectx.offset); + let trace = BinaryTrace::new(air.num_rows()); + let mut prover_buffer = trace.buffer.unwrap(); + + self.binary_sm.prove_instance(inputs, false, &mut prover_buffer); timer_stop_and_log_debug!(PROVE_BINARY); timer_start_debug!(CREATE_AIR_INSTANCE); - let buffer = std::mem::take(&mut iectx.prover_buffer); + let buffer = std::mem::take(&mut prover_buffer); iectx.air_instance = Some(AirInstance::new( self.wcm.get_sctx(), ZISK_AIRGROUP_ID, @@ -277,9 +292,12 @@ impl MainSM { air.num_rows(), ); - self.binary_sm.prove_instance(inputs, true, &mut iectx.prover_buffer, iectx.offset); + let trace = BinaryExtensionTrace::new(air.num_rows()); + let mut prover_buffer = trace.buffer.unwrap(); + + self.binary_sm.prove_instance(inputs, true, &mut prover_buffer); - let buffer = std::mem::take(&mut iectx.prover_buffer); + let buffer = std::mem::take(&mut prover_buffer); iectx.air_instance = Some(AirInstance::new( self.wcm.get_sctx(), ZISK_AIRGROUP_ID, diff --git a/state-machines/mem/src/mem.rs b/state-machines/mem/src/mem.rs index 391bca7b..065b1841 100644 --- a/state-machines/mem/src/mem.rs +++ b/state-machines/mem/src/mem.rs @@ -72,8 +72,8 @@ impl WitnessComponent for MemSM { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + _ectx: Arc, + _sctx: Arc, ) { } } diff --git a/state-machines/mem/src/mem_aligned.rs b/state-machines/mem/src/mem_aligned.rs index 47feebfb..1a126e3c 100644 --- a/state-machines/mem/src/mem_aligned.rs +++ b/state-machines/mem/src/mem_aligned.rs @@ -68,8 +68,8 @@ impl WitnessComponent for MemAlignedSM { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + _ectx: Arc, + _sctx: Arc, ) { } } diff --git a/state-machines/mem/src/mem_unaligned.rs b/state-machines/mem/src/mem_unaligned.rs index 9d47a135..fde238e3 100644 --- a/state-machines/mem/src/mem_unaligned.rs +++ b/state-machines/mem/src/mem_unaligned.rs @@ -70,8 +70,8 @@ impl WitnessComponent for MemUnalignedSM { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + _ectx: Arc, + _sctx: Arc, ) { } } diff --git a/state-machines/publics.json b/state-machines/publics.json new file mode 100644 index 00000000..a7b4243d --- /dev/null +++ b/state-machines/publics.json @@ -0,0 +1,6 @@ +{ + "nPublics": 4, + "definitions": [ + { "name": "rom_root", "initialPos": 0, "chunks": [4, 64] } + ] +} \ No newline at end of file diff --git a/state-machines/quick-ops/src/quick_ops.rs b/state-machines/quick-ops/src/quick_ops.rs index a91b6fd6..901a508e 100644 --- a/state-machines/quick-ops/src/quick_ops.rs +++ b/state-machines/quick-ops/src/quick_ops.rs @@ -58,8 +58,8 @@ impl WitnessComponent for QuickOpsSM { _stage: u32, _air_instance: Option, _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, + _ectx: Arc, + _sctx: Arc, ) { } } diff --git a/state-machines/rom/pil/rom.pil b/state-machines/rom/pil/rom.pil index 2c2011e0..578333bc 100644 --- a/state-machines/rom/pil/rom.pil +++ b/state-machines/rom/pil/rom.pil @@ -2,19 +2,22 @@ require "std_lookup.pil" const int ROM_BUS_ID = 7890; +public rom_root[4]; + airtemplate Rom(int N = 2**21, int stack_enabled = 0, const int rom_bus_id = ROM_BUS_ID) { + commit stage(0) public(rom_root) rom; - col witness line; - col witness a_offset_imm0; - col witness a_imm1; - col witness b_offset_imm0; - col witness b_imm1; - col witness ind_width; - col witness op; - col witness store_offset; - col witness jmp_offset1; - col witness jmp_offset2; - col witness flags; + col rom line; + col rom a_offset_imm0; + col rom a_imm1; + col rom b_offset_imm0; + col rom b_imm1; + col rom ind_width; + col rom op; + col rom store_offset; + col rom jmp_offset1; + col rom jmp_offset2; + col rom flags; col witness multiplicity; diff --git a/state-machines/rom/src/rom.rs b/state-machines/rom/src/rom.rs index 3fc4f9b1..ad79b871 100644 --- a/state-machines/rom/src/rom.rs +++ b/state-machines/rom/src/rom.rs @@ -6,10 +6,11 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{AirInstance, BufferAllocator, SetupCtx}; use proofman_util::create_buffer_fast; -use zisk_core::{Riscv2zisk, ZiskPcHistogram, ZiskRom, SRC_IMM}; -use zisk_pil::{Pilout, RomRow, RomTrace, MAIN_AIR_IDS, ROM_AIR_IDS, ZISK_AIRGROUP_ID}; -//use ziskemu::ZiskEmulatorErr; use std::error::Error; +use zisk_core::{Riscv2zisk, ZiskPcHistogram, ZiskRom, SRC_IMM}; +use zisk_pil::{ + Pilout, RomRomRow, RomRomTrace, RomRow, RomTrace, MAIN_AIR_IDS, ROM_AIR_IDS, ZISK_AIRGROUP_ID, +}; pub struct RomSM { wcm: Arc>, @@ -31,145 +32,23 @@ impl RomSM { rom: &ZiskRom, pc_histogram: ZiskPcHistogram, ) -> Result<(), Box> { - let buffer_allocator = self.wcm.get_ectx().buffer_allocator.clone(); - let sctx = self.wcm.get_sctx(); - if pc_histogram.end_pc == 0 { panic!("RomSM::prove() detected pc_histogram.end_pc == 0"); // TODO: return an error } - let main_trace_len = - self.wcm.get_pctx().pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows() as u64; - - let (prover_buffer, _, air_id) = - Self::compute_trace_rom(rom, buffer_allocator, &sctx, pc_histogram, main_trace_len)?; - - let air_instance = - AirInstance::new(sctx.clone(), ZISK_AIRGROUP_ID, air_id, None, prover_buffer); - let (is_mine, instance_gid) = - self.wcm.get_ectx().dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1); - if is_mine { - self.wcm - .get_pctx() - .air_instance_repo - .add_air_instance(air_instance, Some(instance_gid)); - } - - Ok(()) - } - pub fn compute_trace( - rom_path: PathBuf, - buffer_allocator: Arc>, - sctx: &SetupCtx, - ) -> Result<(Vec, u64, usize), Box> { - // Get the ELF file path as a string - let elf_filename: String = rom_path.to_str().unwrap().into(); - println!("Proving ROM for ELF file={}", elf_filename); - - // Load and parse the ELF file, and transpile it into a ZisK ROM using Riscv2zisk - - // Create an instance of the RISCV -> ZisK program converter - let riscv2zisk = Riscv2zisk::new(elf_filename, String::new(), String::new(), String::new()); - - // Convert program to rom - let rom_result = riscv2zisk.run(); - if rom_result.is_err() { - //return Err(ZiskEmulatorErr::Unknown(zisk_rom.err().unwrap().to_string())); - panic!("RomSM::prover() failed converting elf to rom"); - } - let rom = rom_result.unwrap(); - - let empty_pc_histogram = ZiskPcHistogram::default(); - - Self::compute_trace_rom(&rom, buffer_allocator, sctx, empty_pc_histogram, 0) - } - - pub fn compute_trace_rom( - rom: &ZiskRom, - buffer_allocator: Arc>, - sctx: &SetupCtx, - pc_histogram: ZiskPcHistogram, - main_trace_len: u64, - ) -> Result<(Vec, u64, usize), Box> { - let pilout = Pilout::pilout(); - let sizes = ( - pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(), - // pilout.get_air(ZISK_AIRGROUP_ID, ROM_M_AIR_IDS[0]).num_rows(), - // pilout.get_air(ZISK_AIRGROUP_ID, ROM_L_AIR_IDS[0]).num_rows(), - ); - - let number_of_instructions = rom.insts.len(); - - Self::create_rom_s( - sizes.0, - rom, - number_of_instructions, - buffer_allocator, - sctx, - pc_histogram, - main_trace_len, - ) - // match number_of_instructions { - // n if n <= sizes.0 => Self::create_rom_s( - // sizes.0, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // n if n <= sizes.1 => Self::create_rom_m( - // sizes.1, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // n if n < sizes.2 => Self::create_rom_l( - // sizes.2, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // _ => panic!("RomSM::compute_trace() found rom too big size={}", - // number_of_instructions), } - } - - fn create_rom_s( - rom_s_size: usize, - rom: &zisk_core::ZiskRom, - number_of_instructions: usize, - buffer_allocator: Arc>, - sctx: &SetupCtx, - pc_histogram: ZiskPcHistogram, - main_trace_len: u64, - ) -> Result<(Vec, u64, usize), Box> { - // Set trace size - let trace_size = rom_s_size; - // Allocate a prover buffer - let (buffer_size, offsets) = buffer_allocator - .get_buffer_info(sctx, ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]) - .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - let mut prover_buffer = create_buffer_fast(buffer_size as usize); + let buffer_allocator = self.wcm.get_ectx().buffer_allocator.clone(); + let sctx = self.wcm.get_sctx(); // Create an empty ROM trace - let mut rom_trace = - RomTrace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - .expect("RomSM::compute_trace() failed mapping buffer to ROMSRow"); + let pilout = Pilout::pilout(); + let num_rows = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(); + + let mut rom_trace = RomTrace::new(num_rows); // For every instruction in the rom, fill its corresponding ROM trace - //for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { - let keys = rom.insts.keys(); - let sorted_keys = keys.sorted(); - let mut i = 0; - for key in sorted_keys { + let main_trace_len = pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows() as u64; + for (i, key) in rom.insts.keys().sorted().enumerate() { // Get the Zisk instruction let inst = &rom.insts[key].i; @@ -190,6 +69,88 @@ impl RomSM { continue; // We skip those pc's that are not used in this execution } } + rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); + } + + // Padd with zeroes + for i in rom.insts.len()..num_rows { + rom_trace[i] = RomRow::default(); + } + + let mut air_instance = AirInstance::new( + sctx.clone(), + ZISK_AIRGROUP_ID, + ROM_AIR_IDS[0], + None, + rom_trace.buffer.unwrap(), + ); + + match self + .wcm + .get_ectx() + .cached_buffers_path + .as_ref() + .and_then(|cached_buffers| cached_buffers.get("rom").cloned()) + { + Some(buffer_path) => { + let (_, _, commit_id) = buffer_allocator + .clone() + .get_buffer_info_custom_commit(&sctx, ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], "rom") + .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); + + air_instance.set_custom_commit_cached_file(&sctx, commit_id, buffer_path); + } + None => { + let (commit_id_rom, prover_buffer_rom) = + Self::compute_trace_rom(rom, buffer_allocator.clone(), &sctx)?; + + air_instance.set_custom_commit_id_buffer(&sctx, prover_buffer_rom, commit_id_rom); + } + } + + let (commit_id_rom, prover_buffer_rom) = + Self::compute_trace_rom(rom, buffer_allocator.clone(), &sctx)?; + + air_instance.set_custom_commit_id_buffer(&sctx, prover_buffer_rom, commit_id_rom); + + let (is_mine, instance_gid) = self.wcm.get_ectx().dctx.write().unwrap().add_instance( + ZISK_AIRGROUP_ID, + ROM_AIR_IDS[0], + 1, + ); + if is_mine { + self.wcm + .get_pctx() + .air_instance_repo + .add_air_instance(air_instance, Some(instance_gid)); + } + + Ok(()) + } + + pub fn compute_trace_rom( + rom: &ZiskRom, + buffer_allocator: Arc, + sctx: &SetupCtx, + ) -> Result<(u64, Vec), Box> { + // Allocate a prover buffer + let (buffer_size_rom, offsets_rom, commit_id) = buffer_allocator + .get_buffer_info_custom_commit(sctx, ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], "rom") + .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); + + // Create an empty ROM trace + let pilout = Pilout::pilout(); + let trace_rows = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(); + let mut prover_buffer = create_buffer_fast(buffer_size_rom as usize); + + let mut rom_trace = + RomRomTrace::::map_buffer(&mut prover_buffer, trace_rows, offsets_rom[0] as usize) + .expect("RomRootSM::compute_trace() failed mapping buffer to ROMSRow"); + + // For every instruction in the rom, fill its corresponding ROM trace + for (i, key) in rom.insts.keys().sorted().enumerate() { + // Get the Zisk instruction + let inst = &rom.insts[key].i; // Convert the i64 offsets to F let jmp_offset1 = if inst.jmp_offset1 >= 0 { @@ -226,239 +187,41 @@ impl RomSM { rom_trace[i].b_offset_imm0 = b_offset_imm0; rom_trace[i].b_imm1 = F::from_canonical_u64(if inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 }); - //rom_trace[i].b_src_ind = - // F::from_canonical_u64(if inst.b_src == SRC_IND { 1 } else { 0 }); rom_trace[i].ind_width = F::from_canonical_u64(inst.ind_width); rom_trace[i].op = F::from_canonical_u8(inst.op); rom_trace[i].store_offset = store_offset; rom_trace[i].jmp_offset1 = jmp_offset1; rom_trace[i].jmp_offset2 = jmp_offset2; rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); - rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - /*println!( - "ROM SM [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], {}", - inst.paddr, - inst.a_offset_imm0, - if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 }, - inst.b_offset_imm0, - if inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 }, - if inst.b_src == SRC_IND { 1 } else { 0 }, - inst.ind_width, - inst.op, - inst.store_offset as u64, - inst.jmp_offset1 as u64, - inst.jmp_offset2 as u64, - inst.get_flags(), - multiplicity, - );*/ - i += 1; } // Padd with zeroes - for i in number_of_instructions..trace_size { - rom_trace[i] = RomRow::default(); + for i in rom.insts.len()..trace_rows { + rom_trace[i] = RomRomRow::default(); } - Ok((prover_buffer, offsets[0], ROM_AIR_IDS[0])) + Ok((commit_id, prover_buffer)) } - // fn create_rom_m( - // rom_m_size: usize, - // rom: &zisk_core::ZiskRom, - // number_of_instructions: usize, - // buffer_allocator: Arc, - // sctx: &SetupCtx, - // pc_histogram: ZiskPcHistogram, - // main_trace_len: u64, - // ) -> Result<(Vec, u64, usize), Box> { - // // Set trace size - // let trace_size = rom_m_size; - - // // Allocate a prover buffer - // let (buffer_size, offsets) = buffer_allocator - // .get_buffer_info(sctx, ZISK_AIRGROUP_ID, ROM_M_AIR_IDS[0]) - // .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - // let mut prover_buffer = create_buffer_fast(buffer_size as usize); - - // // Create an empty ROM trace - // let mut rom_trace = - // RomM1Trace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - // .expect("RomSM::compute_trace() failed mapping buffer to ROMMRow"); - - // // For every instruction in the rom, fill its corresponding ROM trace - // for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { - // // Get the Zisk instruction - // let inst = inst_builder.1.i; - - // // Calculate the multiplicity, i.e. the number of times this pc is used in this - // // execution - // let mut multiplicity: u64; - // if pc_histogram.map.is_empty() { - // multiplicity = 1; // If the histogram is empty, we use 1 for all pc's - // } else { - // let counter = pc_histogram.map.get(&inst.paddr); - // if counter.is_some() { - // multiplicity = *counter.unwrap(); - // if inst.paddr == pc_histogram.end_pc { - // multiplicity += main_trace_len - 1 - (pc_histogram.steps % - // main_trace_len); } - // } else { - // continue; // We skip those pc's that are not used in this execution - // } - // } - - // // Convert the i64 offsets to F - // let jmp_offset1 = if inst.jmp_offset1 >= 0 { - // F::from_canonical_u64(inst.jmp_offset1 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset1) as u64)) - // }; - // let jmp_offset2 = if inst.jmp_offset2 >= 0 { - // F::from_canonical_u64(inst.jmp_offset2 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset2) as u64)) - // }; - // let store_offset = if inst.store_offset >= 0 { - // F::from_canonical_u64(inst.store_offset as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.store_offset) as u64)) - // }; - // let a_offset_imm0 = if inst.a_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.a_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.a_offset_imm0 as i64)) as u64)) - // }; - // let b_offset_imm0 = if inst.b_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.b_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.b_offset_imm0 as i64)) as u64)) - // }; - - // // Fill the rom trace row fields - // rom_trace[i].line = F::from_canonical_u64(inst.paddr); // TODO: unify names: pc, - // paddr, line rom_trace[i].a_offset_imm0 = a_offset_imm0; - // rom_trace[i].a_imm1 = - // F::from_canonical_u64(if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 - // }); rom_trace[i].b_offset_imm0 = b_offset_imm0; - // rom_trace[i].b_imm1 = - // F::from_canonical_u64(if inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 - // }); //rom_trace[i].b_src_ind = - // // F::from_canonical_u64(if inst.b_src == SRC_IND { 1 } else { 0 }); - // rom_trace[i].ind_width = F::from_canonical_u64(inst.ind_width); - // rom_trace[i].op = F::from_canonical_u8(inst.op); - // rom_trace[i].store_offset = store_offset; - // rom_trace[i].jmp_offset1 = jmp_offset1; - // rom_trace[i].jmp_offset2 = jmp_offset2; - // rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); - // rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - // } - - // // Padd with zeroes - // for i in number_of_instructions..trace_size { - // rom_trace[i] = RomM1Row::default(); - // } - - // Ok((prover_buffer, offsets[0], ROM_M_AIR_IDS[0])) - // } - - // fn create_rom_l( - // rom_l_size: usize, - // rom: &zisk_core::ZiskRom, - // number_of_instructions: usize, - // buffer_allocator: Arc, - // sctx: &SetupCtx, - // pc_histogram: ZiskPcHistogram, - // main_trace_len: u64, - // ) -> Result<(Vec, u64, usize), Box> { - // // Set trace size - // let trace_size = rom_l_size; - - // // Allocate a prover buffer - // let (buffer_size, offsets) = buffer_allocator - // .get_buffer_info(sctx, ZISK_AIRGROUP_ID, ROM_L_AIR_IDS[0]) - // .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - // let mut prover_buffer = create_buffer_fast(buffer_size as usize); - - // // Create an empty ROM trace - // let mut rom_trace = - // RomL2Trace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - // .expect("RomSM::compute_trace() failed mapping buffer to ROMLRow"); - - // // For every instruction in the rom, fill its corresponding ROM trace - // for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { - // // Get the Zisk instruction - // let inst = inst_builder.1.i; - - // // Calculate the multiplicity, i.e. the number of times this pc is used in this - // // execution - // let mut multiplicity: u64; - // if pc_histogram.map.is_empty() { - // multiplicity = 1; // If the histogram is empty, we use 1 for all pc's - // } else { - // let counter = pc_histogram.map.get(&inst.paddr); - // if counter.is_some() { - // multiplicity = *counter.unwrap(); - // if inst.paddr == pc_histogram.end_pc { - // multiplicity += main_trace_len - 1 - (pc_histogram.steps % - // main_trace_len); } - // } else { - // continue; // We skip those pc's that are not used in this execution - // } - // } - - // // Convert the i64 offsets to F - // let jmp_offset1 = if inst.jmp_offset1 >= 0 { - // F::from_canonical_u64(inst.jmp_offset1 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset1) as u64)) - // }; - // let jmp_offset2 = if inst.jmp_offset2 >= 0 { - // F::from_canonical_u64(inst.jmp_offset2 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset2) as u64)) - // }; - // let store_offset = if inst.store_offset >= 0 { - // F::from_canonical_u64(inst.store_offset as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.store_offset) as u64)) - // }; - // let a_offset_imm0 = if inst.a_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.a_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.a_offset_imm0 as i64)) as u64)) - // }; - // let b_offset_imm0 = if inst.b_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.b_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.b_offset_imm0 as i64)) as u64)) - // }; - - // // Fill the rom trace row fields - // rom_trace[i].line = F::from_canonical_u64(inst.paddr); // TODO: unify names: pc, - // paddr, line rom_trace[i].a_offset_imm0 = a_offset_imm0; - // rom_trace[i].a_imm1 = - // F::from_canonical_u64(if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 - // }); rom_trace[i].b_offset_imm0 = b_offset_imm0; - // rom_trace[i].b_imm1 = - // F::from_canonical_u64(if inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 - // }); //rom_trace[i].b_src_ind = - // // F::from_canonical_u64(if inst.b_src == SRC_IND { 1 } else { 0 }); - // rom_trace[i].ind_width = F::from_canonical_u64(inst.ind_width); - // rom_trace[i].op = F::from_canonical_u8(inst.op); - // rom_trace[i].store_offset = store_offset; - // rom_trace[i].jmp_offset1 = jmp_offset1; - // rom_trace[i].jmp_offset2 = jmp_offset2; - // rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); - // rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - // } - - // // Padd with zeroes - // for i in number_of_instructions..trace_size { - // rom_trace[i] = RomL2Row::default(); - // } - - // Ok((prover_buffer, offsets[0], ROM_L_AIR_IDS[0])) - // } + pub fn compute_trace_rom_buffer( + rom_path: PathBuf, + buffer_allocator: Arc, + sctx: &SetupCtx, + ) -> Result<(u64, Vec), Box> { + // Get the ELF file path as a string + let elf_filename: String = rom_path.to_str().unwrap().into(); + println!("Proving ROM for ELF file={}", elf_filename); + + // Load and parse the ELF file, and transpile it into a ZisK ROM using Riscv2zisk + + // Create an instance of the RISCV -> ZisK program converter + let riscv2zisk = Riscv2zisk::new(elf_filename, String::new(), String::new(), String::new()); + + // Convert program to rom + let rom = riscv2zisk.run().expect("RomSM::prover() failed converting elf to rom"); + + Self::compute_trace_rom(&rom, buffer_allocator, sctx) + } } impl WitnessComponent for RomSM {} diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index b77f6bfe..9e3834a5 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -8,7 +8,6 @@ use rayon::prelude::*; use sm_arith::ArithSM; use sm_binary::BinarySM; -use sm_common::create_prover_buffer; use sm_main::{InstanceExtensionCtx, MainSM}; use sm_mem::MemSM; use sm_rom::RomSM; @@ -100,8 +99,8 @@ impl ZiskExecutor { rom_path: &Path, public_inputs_path: &Path, pctx: Arc>, - ectx: Arc>, - sctx: Arc>, + ectx: Arc, + _sctx: Arc, ) { let air_main = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]); @@ -193,10 +192,7 @@ impl ZiskExecutor { }; if let (true, global_idx) = dctx.add_instance(airgroup_id, air_id, 1) { - let (buffer, offset) = create_prover_buffer::(&ectx, &sctx, airgroup_id, air_id); instances_extension_ctx.push(InstanceExtensionCtx::new( - buffer, - offset, emu_slice.op_type, emu_slice.emu_trace_start.clone(), segment_id, diff --git a/witness-computation/src/zisk_lib.rs b/witness-computation/src/zisk_lib.rs index 02a78d12..ac9d8ec8 100644 --- a/witness-computation/src/zisk_lib.rs +++ b/witness-computation/src/zisk_lib.rs @@ -47,12 +47,7 @@ impl ZiskWitness { }) } - fn initialize( - &mut self, - pctx: Arc>, - ectx: Arc>, - sctx: Arc>, - ) { + fn initialize(&mut self, pctx: Arc>, ectx: Arc, sctx: Arc) { let wcm = WitnessManager::new(pctx, ectx, sctx); let wcm = Arc::new(wcm); @@ -65,8 +60,8 @@ impl WitnessLibrary for ZiskWitness { fn start_proof( &mut self, pctx: Arc>, - ectx: Arc>, - sctx: Arc>, + ectx: Arc, + sctx: Arc, ) { self.initialize(pctx.clone(), ectx.clone(), sctx.clone()); @@ -76,7 +71,7 @@ impl WitnessLibrary for ZiskWitness { fn end_proof(&mut self) { self.wcm.get().unwrap().end_proof(); } - fn execute(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { + fn execute(&self, pctx: Arc>, ectx: Arc, sctx: Arc) { timer_start_info!(EXECUTE); self.executor.get().unwrap().execute( &self.rom_path, @@ -92,8 +87,8 @@ impl WitnessLibrary for ZiskWitness { &mut self, stage: u32, pctx: Arc>, - ectx: Arc>, - sctx: Arc>, + ectx: Arc, + sctx: Arc, ) { self.wcm.get().unwrap().calculate_witness(stage, pctx, ectx, sctx); } From 80192db01dbc30116fe9c7fdeb846e9c705426bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip=20Ardevol?= Date: Wed, 11 Dec 2024 13:02:01 +0100 Subject: [PATCH 2/6] Division check for arith component (#181) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Removing unnecessary code * Feature/custom commits (#166) * Custom cols rom (#159) Custom cols working --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> * Cached custom commits * Updating proofman to 0.0.12 * Cargo fmt * Cargo fmt * Fix cargo clippy * Rom trace is now deterministic * cargo fmt * Global constraints verifying again * Optimizing the binary component (#167) * Optimizing the binary * Updating the executor * Updating to 0.0.13 * Not creating unnecessary instances of arith tables * Pil2-proofman 0.0.14 --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Co-authored-by: Héctor Masip Ardevol * Zisk working with last proofman version * Updating book and Cargo.toml to point to 0.0.16 proofman * Table augmented * New opcodes introduced to binary * Added the arith-binary lookup * Added the debug mode to the verify all script * working except for GT op * Adding the gt op * rebasing update proofman * GT fully working * cargo fmt, cargo fix * Cargo * remove comments * minor changes --------- Co-authored-by: Roger Taulé Buxadera <55488871+RogerTaule@users.noreply.github.com> Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Co-authored-by: RogerTaule --- Cargo.lock | 118 ++- core/src/zisk_ops.rs | 68 +- pil/src/pil_helpers/pilout.rs | 2 +- pil/src/pil_helpers/traces.rs | 2 +- state-machines/arith/Cargo.toml | 1 + state-machines/arith/pil/arith.pil | 15 +- state-machines/arith/src/arith.rs | 12 +- state-machines/arith/src/arith_full.rs | 63 +- state-machines/binary/pil/binary.pil | 114 ++- .../binary/pil/binary_extension.pil | 18 +- .../binary/pil/binary_extension_table.pil | 59 +- state-machines/binary/pil/binary_table.pil | 280 ++++--- state-machines/binary/src/binary.rs | 60 +- state-machines/binary/src/binary_basic.rs | 696 ++++++++++++------ .../binary/src/binary_basic_table.rs | 153 ++-- state-machines/binary/src/binary_extension.rs | 16 +- .../binary/src/binary_extension_table.rs | 31 +- state-machines/main/pil/main.pil | 4 +- tools/verify_all.sh | 17 +- witness-computation/src/executor.rs | 2 +- 20 files changed, 1069 insertions(+), 662 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 84d55e9f..e59ff2b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -198,9 +198,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" +checksum = "27f657647bcff5394bf56c7317665bbf790a137a50eaaa5c6bfbb9e27a518f2d" dependencies = [ "jobserver", "libc", @@ -248,9 +248,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.22" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69371e34337c4c984bbe322360c2547210bf632eb2814bbe78a6e87a2935bd2b" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" dependencies = [ "clap_builder", "clap_derive", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.22" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e24c1b4099818523236a8ca881d2b45db98dadfb4625cf6608c12069fcbbde1" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" dependencies = [ "anstream", "anstyle", @@ -282,9 +282,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "colorchoice" @@ -506,9 +506,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "findshlibs" @@ -1048,9 +1048,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" +checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" dependencies = [ "once_cell", "wasm-bindgen", @@ -1070,9 +1070,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.167" +version = "0.2.168" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" [[package]] name = "libgit2-sys" @@ -1675,9 +1675,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f" +checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec" dependencies = [ "bytes", "prost-derive", @@ -1685,11 +1685,10 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" +checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b" dependencies = [ - "bytes", "heck", "itertools 0.13.0", "log", @@ -1706,9 +1705,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" +checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3" dependencies = [ "anyhow", "itertools 0.13.0", @@ -1719,9 +1718,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670" +checksum = "cc2f1e56baa61e93533aebc21af4d2134b70f66275e0fcdf3cbe43d77ff7e8fc" dependencies = [ "prost", ] @@ -1748,7 +1747,7 @@ dependencies = [ "rustc-hash", "rustls", "socket2", - "thiserror 2.0.4", + "thiserror 2.0.6", "tokio", "tracing", ] @@ -1767,7 +1766,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.4", + "thiserror 2.0.6", "tinyvec", "tracing", "web-time", @@ -1775,9 +1774,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" +checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527" dependencies = [ "cfg_aliases", "libc", @@ -2002,15 +2001,15 @@ checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" [[package]] name = "rustix" -version = "0.38.41" +version = "0.38.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" +checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" dependencies = [ "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2085,18 +2084,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", @@ -2165,6 +2164,7 @@ dependencies = [ "proofman-macros", "proofman-util", "rayon", + "sm-binary", "sm-common", "zisk-core", "zisk-pil", @@ -2463,11 +2463,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.4" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f49a1853cf82743e3b7950f77e0f4d622ca36cf4317cba00c767838bac8d490" +checksum = "8fec2a1820ebd077e2b90c4df007bebf344cd394098a13c563957d0afc83ea47" dependencies = [ - "thiserror-impl 2.0.4", + "thiserror-impl 2.0.6", ] [[package]] @@ -2483,9 +2483,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.4" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8381894bb3efe0c4acac3ded651301ceee58a15d47c2e34885ed1908ad667061" +checksum = "d65750cab40f4ff1929fb1ba509e9914eb756131cef4210da8d5d700d26f6312" dependencies = [ "proc-macro2", "quote", @@ -2600,20 +2600,19 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ "rustls", - "rustls-pki-types", "tokio", ] [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", @@ -2794,9 +2793,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" +checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" dependencies = [ "cfg-if", "once_cell", @@ -2805,13 +2804,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" +checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn", @@ -2820,9 +2818,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.47" +version = "0.4.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dfaf8f50e5f293737ee323940c7d8b08a66a95a419223d9f41610ca08b0833d" +checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" dependencies = [ "cfg-if", "js-sys", @@ -2833,9 +2831,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" +checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2843,9 +2841,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" +checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", @@ -2856,9 +2854,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" +checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" [[package]] name = "wasm-streams" @@ -2875,9 +2873,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a98bc3c33f0fe7e59ad7cd041b89034fa82a7c2d4365ca538dda6cdaf513863c" +checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/core/src/zisk_ops.rs b/core/src/zisk_ops.rs index f7763cbf..71efe57f 100644 --- a/core/src/zisk_ops.rs +++ b/core/src/zisk_ops.rs @@ -235,32 +235,32 @@ macro_rules! define_ops { define_ops! { (Flag, "flag", Internal, 0, 0x00, opc_flag, op_flag), (CopyB, "copyb", Internal, 0, 0x01, opc_copyb, op_copyb), - (SignExtendB, "signextend_b", BinaryE, 109, 0x23, opc_signextend_b, op_signextend_b), - (SignExtendH, "signextend_h", BinaryE, 109, 0x24, opc_signextend_h, op_signextend_h), - (SignExtendW, "signextend_w", BinaryE, 109, 0x25, opc_signextend_w, op_signextend_w), - (Add, "add", Binary, 77, 0x02, opc_add, op_add), - (AddW, "add_w", Binary, 77, 0x12, opc_add_w, op_add_w), - (Sub, "sub", Binary, 77, 0x03, opc_sub, op_sub), - (SubW, "sub_w", Binary, 77, 0x13, opc_sub_w, op_sub_w), - (Sll, "sll", BinaryE, 109, 0x0d, opc_sll, op_sll), - (SllW, "sll_w", BinaryE, 109, 0x1d, opc_sll_w, op_sll_w), - (Sra, "sra", BinaryE, 109, 0x0f, opc_sra, op_sra), - (Srl, "srl", BinaryE, 109, 0x0e, opc_srl, op_srl), - (SraW, "sra_w", BinaryE, 109, 0x1f, opc_sra_w, op_sra_w), - (SrlW, "srl_w", BinaryE, 109, 0x1e, opc_srl_w, op_srl_w), - (Eq, "eq", Binary, 77, 0x08, opc_eq, op_eq), - (EqW, "eq_w", Binary, 77, 0x18, opc_eq_w, op_eq_w), - (Ltu, "ltu", Binary, 77, 0x04, opc_ltu, op_ltu), - (Lt, "lt", Binary, 77, 0x05, opc_lt, op_lt), - (LtuW, "ltu_w", Binary, 77, 0x14, opc_ltu_w, op_ltu_w), - (LtW, "lt_w", Binary, 77, 0x15, opc_lt_w, op_lt_w), - (Leu, "leu", Binary, 77, 0x06, opc_leu, op_leu), - (Le, "le", Binary, 77, 0x07, opc_le, op_le), - (LeuW, "leu_w", Binary, 77, 0x16, opc_leu_w, op_leu_w), - (LeW, "le_w", Binary, 77, 0x17, opc_le_w, op_le_w), - (And, "and", Binary, 77, 0x20, opc_and, op_and), - (Or, "or", Binary, 77, 0x21, opc_or, op_or), - (Xor, "xor", Binary, 77, 0x22, opc_xor, op_xor), + (SignExtendB, "signextend_b", BinaryE, 109, 0x37, opc_signextend_b, op_signextend_b), + (SignExtendH, "signextend_h", BinaryE, 109, 0x38, opc_signextend_h, op_signextend_h), + (SignExtendW, "signextend_w", BinaryE, 109, 0x39, opc_signextend_w, op_signextend_w), + (Add, "add", Binary, 77, 0x0c, opc_add, op_add), + (AddW, "add_w", Binary, 77, 0x2c, opc_add_w, op_add_w), + (Sub, "sub", Binary, 77, 0x0d, opc_sub, op_sub), + (SubW, "sub_w", Binary, 77, 0x2d, opc_sub_w, op_sub_w), + (Sll, "sll", BinaryE, 109, 0x31, opc_sll, op_sll), + (SllW, "sll_w", BinaryE, 109, 0x34, opc_sll_w, op_sll_w), + (Sra, "sra", BinaryE, 109, 0x33, opc_sra, op_sra), + (Srl, "srl", BinaryE, 109, 0x32, opc_srl, op_srl), + (SraW, "sra_w", BinaryE, 109, 0x36, opc_sra_w, op_sra_w), + (SrlW, "srl_w", BinaryE, 109, 0x35, opc_srl_w, op_srl_w), + (Eq, "eq", Binary, 77, 0x0b, opc_eq, op_eq), + (EqW, "eq_w", Binary, 77, 0x2b, opc_eq_w, op_eq_w), + (Ltu, "ltu", Binary, 77, 0x08, opc_ltu, op_ltu), + (Lt, "lt", Binary, 77, 0x09, opc_lt, op_lt), + (LtuW, "ltu_w", Binary, 77, 0x28, opc_ltu_w, op_ltu_w), + (LtW, "lt_w", Binary, 77, 0x29, opc_lt_w, op_lt_w), + (Leu, "leu", Binary, 77, 0x0e, opc_leu, op_leu), + (Le, "le", Binary, 77, 0x0f, opc_le, op_le), + (LeuW, "leu_w", Binary, 77, 0x2e, opc_leu_w, op_leu_w), + (LeW, "le_w", Binary, 77, 0x2f, opc_le_w, op_le_w), + (And, "and", Binary, 77, 0x10, opc_and, op_and), + (Or, "or", Binary, 77, 0x11, opc_or, op_or), + (Xor, "xor", Binary, 77, 0x12, opc_xor, op_xor), (Mulu, "mulu", ArithAm32, 97, 0xb0, opc_mulu, op_mulu), (Muluh, "muluh", ArithAm32, 97, 0xb1, opc_muluh, op_muluh), (Mulsuh, "mulsuh", ArithAm32, 97, 0xb3, opc_mulsuh, op_mulsuh), @@ -275,14 +275,14 @@ define_ops! { (RemuW, "remu_w", ArithA32, 136, 0xbd, opc_remu_w, op_remu_w), (DivW, "div_w", ArithA32, 136, 0xbe, opc_div_w, op_div_w), (RemW, "rem_w", ArithA32, 136, 0xbf, opc_rem_w, op_rem_w), - (Minu, "minu", Binary, 77, 0x09, opc_minu, op_minu), - (Min, "min", Binary, 77, 0x0a, opc_min, op_min), - (MinuW, "minu_w", Binary, 77, 0x19, opc_minu_w, op_minu_w), - (MinW, "min_w", Binary, 77, 0x1a, opc_min_w, op_min_w), - (Maxu, "maxu", Binary, 77, 0x0b, opc_maxu, op_maxu), - (Max, "max", Binary, 77, 0x0c, opc_max, op_max), - (MaxuW, "maxu_w", Binary, 77, 0x1b, opc_maxu_w, op_maxu_w), - (MaxW, "max_w", Binary, 77, 0x1c, opc_max_w, op_max_w), + (Minu, "minu", Binary, 77, 0x02, opc_minu, op_minu), + (Min, "min", Binary, 77, 0x03, opc_min, op_min), + (MinuW, "minu_w", Binary, 77, 0x22, opc_minu_w, op_minu_w), + (MinW, "min_w", Binary, 77, 0x23, opc_min_w, op_min_w), + (Maxu, "maxu", Binary, 77, 0x04, opc_maxu, op_maxu), + (Max, "max", Binary, 77, 0x05, opc_max, op_max), + (MaxuW, "maxu_w", Binary, 77, 0x24, opc_maxu_w, op_maxu_w), + (MaxW, "max_w", Binary, 77, 0x25, opc_max_w, op_max_w), (Keccak, "keccak", Keccak, 77, 0xf1, opc_keccak, op_keccak), (PubOut, "pubout", PubOut, 77, 0x30, opc_pubout, op_pubout), // TODO: New type } diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 9a796335..9098a62b 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -44,7 +44,7 @@ impl Pilout { air_group.add_air(Some("ArithTable"), 128); air_group.add_air(Some("ArithRangeTable"), 4194304); air_group.add_air(Some("Binary"), 2097152); - air_group.add_air(Some("BinaryTable"), 4194304); + air_group.add_air(Some("BinaryTable"), 8388608); air_group.add_air(Some("BinaryExtension"), 2097152); air_group.add_air(Some("BinaryExtensionTable"), 4194304); air_group.add_air(Some("SpecifiedRanges"), 16777216); diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index da3c9ff1..4b09b133 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -24,7 +24,7 @@ trace!(ArithRangeTableRow, ArithRangeTableTrace { }); trace!(BinaryRow, BinaryTrace { - m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, cout: F, result_is_a: F, use_last_carry_mode32: F, use_last_carry_mode64: F, m_op_or_ext: F, free_in_a_or_c: [F; 4], free_in_b_or_zero: [F; 4], multiplicity: F, main_step: F, + m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, has_initial_carry: F, cout: F, result_is_a: F, use_last_carry_mode32: F, use_last_carry_mode64: F, m_op_or_ext: F, free_in_a_or_c: [F; 4], free_in_b_or_zero: [F; 4], multiplicity: F, main_step: F, }); trace!(BinaryTableRow, BinaryTableTrace { diff --git a/state-machines/arith/Cargo.toml b/state-machines/arith/Cargo.toml index 7c0859b8..2e17605d 100644 --- a/state-machines/arith/Cargo.toml +++ b/state-machines/arith/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" zisk-core = { path = "../../core" } zisk-pil = { path="../../pil" } sm-common = { path = "../common" } +sm-binary = { path = "../binary" } p3-field = { workspace=true } proofman-common = { workspace = true } diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 1a830cb5..03c2cfc8 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -7,8 +7,6 @@ require "arith_range_table.pil" // full mul_64 full_32 mul_32 // TOTAL 88 77 57 44 -const int OP_LT_ABS = 0x9F; - airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_result = 0) { const int CHUNK_SIZE = 2**16; @@ -275,11 +273,14 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu bus_res0, bus_res1, div_by_zero /*+ div_overflow*/], mul: multiplicity); - // TODO: remainder check - // lookup_assumes(operation_bus_id, [debug_main_step, signed * (OP_LT_ABS - OP_LT) + OP_LT, - // (d[0] + CHUNK_SIZE * d[1]), (d[2] + CHUNK_SIZE * d[3]) + m32 * nr * 0xFFFFFFFF, - // (b[0] + CHUNK_SIZE * b[1]), (b[2] + CHUNK_SIZE * b[3]) + m32 * nb * 0xFFFFFFFF, - // 1, 0, 1], sel: div); + // Check that remainder (d) is lower than divisor (b) when division is performed + // Specifically, we ensure that 0 <= |d| < |b| + lookup_assumes(operation_bus_id, [debug_main_step, + (1 - nr) * (1 - nb) * LTU_OP + nr * (1 - nb) * LT_ABS_NP_OP + (1 - nr) * nb * LT_ABS_PN_OP + nr * nb * GT_OP, + (d[0] + CHUNK_SIZE * d[1]), (d[2] + CHUNK_SIZE * d[3]) + m32 * nr * 0xFFFFFFFF, // remainder + (b[0] + CHUNK_SIZE * b[1]), (b[2] + CHUNK_SIZE * b[3]) + m32 * nb * 0xFFFFFFFF, // divisor + 1, 0, + 1], sel: div * (1 - div_by_zero)); for (int index = 0; index < length(carry); ++index) { arith_range_table_assumes(ARITH_RANGE_CARRY, carry[index]); // TODO: review carry range diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index d14d1b7b..de66eb32 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -3,15 +3,16 @@ use std::sync::{ Arc, }; -use p3_field::Field; +use p3_field::PrimeField; use proofman::{WitnessComponent, WitnessManager}; +use sm_binary::BinarySM; use zisk_core::ZiskRequiredOperation; use zisk_pil::{ARITH_AIR_IDS, ARITH_RANGE_TABLE_AIR_IDS, ARITH_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{ArithFullSM, ArithRangeTableSM, ArithTableSM}; #[allow(dead_code)] -pub struct ArithSM { +pub struct ArithSM { // Count of registered predecessors registered_predecessors: AtomicU32, @@ -20,8 +21,8 @@ pub struct ArithSM { arith_range_table_sm: Arc>, } -impl ArithSM { - pub fn new(wcm: Arc>) -> Arc { +impl ArithSM { + pub fn new(wcm: Arc>, binary_sm: Arc>) -> Arc { let arith_table_sm = ArithTableSM::new(wcm.clone(), ZISK_AIRGROUP_ID, ARITH_TABLE_AIR_IDS); let arith_range_table_sm = ArithRangeTableSM::new(wcm.clone(), ZISK_AIRGROUP_ID, ARITH_RANGE_TABLE_AIR_IDS); @@ -29,6 +30,7 @@ impl ArithSM { wcm.clone(), arith_table_sm.clone(), arith_range_table_sm.clone(), + binary_sm, ZISK_AIRGROUP_ID, ARITH_AIR_IDS, ); @@ -60,4 +62,4 @@ impl ArithSM { } } -impl WitnessComponent for ArithSM {} +impl WitnessComponent for ArithSM {} diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 954a4030..99d71a65 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -7,30 +7,36 @@ use crate::{ ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithTableInputs, ArithTableSM, }; use log::info; -use p3_field::Field; +use p3_field::PrimeField; use proofman::{WitnessComponent, WitnessManager}; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; +use sm_binary::{BinarySM, GT_OP, LTU_OP, LT_ABS_NP_OP, LT_ABS_PN_OP}; use sm_common::i64_to_u64_field; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::*; -pub struct ArithFullSM { +const CHUNK_SIZE: u64 = 0x10000; +const EXTENSION: u64 = 0xFFFFFFFF; + +pub struct ArithFullSM { wcm: Arc>, // Count of registered predecessors registered_predecessors: AtomicU32, - // Inputs + // Secondary State Machines arith_table_sm: Arc>, arith_range_table_sm: Arc>, + binary_sm: Arc>, } -impl ArithFullSM { +impl ArithFullSM { const MY_NAME: &'static str = "Arith "; pub fn new( wcm: Arc>, arith_table_sm: Arc>, arith_range_table_sm: Arc>, + binary_sm: Arc>, airgroup_id: usize, air_ids: &[usize], ) -> Arc { @@ -39,6 +45,7 @@ impl ArithFullSM { registered_predecessors: AtomicU32::new(0), arith_table_sm, arith_range_table_sm, + binary_sm, }; let arith_full_sm = Arc::new(arith_full_sm); @@ -46,6 +53,7 @@ impl ArithFullSM { arith_full_sm.arith_table_sm.register_predecessor(); arith_full_sm.arith_range_table_sm.register_predecessor(); + arith_full_sm.binary_sm.register_predecessor(); arith_full_sm } @@ -58,6 +66,7 @@ impl ArithFullSM { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { self.arith_table_sm.unregister_predecessor(); self.arith_range_table_sm.unregister_predecessor(); + self.binary_sm.unregister_predecessor(); } } pub fn prove_instance(&self, input: Vec, prover_buffer: &mut [F]) { @@ -80,6 +89,7 @@ impl ArithFullSM { let mut traces = ArithTrace::::map_buffer(prover_buffer, num_rows, 0).unwrap(); let mut aop = ArithOperation::new(); + let mut binary_inputs = Vec::new(); for (irow, input) in input.iter().enumerate() { aop.calculate(input.opcode, input.a, input.b); let mut t: ArithRow = Default::default(); @@ -179,6 +189,42 @@ impl ArithFullSM { aop.d[2] + (aop.d[3] << 16) }); traces[irow] = t; + + // If the operation is a division, then use the binary component + // to check that the remainer is lower than the divisor + if aop.div && !aop.div_by_zero { + let opcode = match (aop.nr, aop.nb) { + (false, false) => LTU_OP, + (false, true) => LT_ABS_PN_OP, + (true, false) => LT_ABS_NP_OP, + (true, true) => GT_OP, + }; + + let extension = match (aop.m32, aop.nr, aop.nb) { + (false, _, _) => (0, 0), + (true, false, false) => (0, 0), + (true, false, true) => (0, EXTENSION), + (true, true, false) => (EXTENSION, 0), + (true, true, true) => (EXTENSION, EXTENSION), + }; + + // TODO: We dont need to "glue" the d,b chunks back, we can use the aop API to do + // this! + let operation = ZiskRequiredOperation { + step: input.step, + opcode, + a: aop.d[0] + + CHUNK_SIZE * aop.d[1] + + CHUNK_SIZE.pow(2) * (aop.d[2] + extension.0) + + CHUNK_SIZE.pow(3) * aop.d[3], + b: aop.b[0] + + CHUNK_SIZE * aop.b[1] + + CHUNK_SIZE.pow(2) * (aop.b[2] + extension.1) + + CHUNK_SIZE.pow(3) * aop.b[3], + }; + + binary_inputs.push(operation); + } } timer_stop_and_log_trace!(ARITH_TRACE); @@ -219,7 +265,14 @@ impl ArithFullSM { timer_start_trace!(ARITH_RANGE_TABLE); self.arith_range_table_sm.process_slice(&range_table_inputs); timer_stop_and_log_trace!(ARITH_RANGE_TABLE); + + if !binary_inputs.is_empty() { + timer_start_trace!(ARITH_BINARY); + info!("{}: ··· calling binary_sm", Self::MY_NAME); + self.binary_sm.prove(binary_inputs.as_slice(), false); + timer_stop_and_log_trace!(ARITH_BINARY); + } } } -impl WitnessComponent for ArithFullSM {} +impl WitnessComponent for ArithFullSM {} diff --git a/state-machines/binary/pil/binary.pil b/state-machines/binary/pil/binary.pil index de055894..8d8d5634 100644 --- a/state-machines/binary/pil/binary.pil +++ b/state-machines/binary/pil/binary.pil @@ -4,61 +4,51 @@ require "std_lookup.pil" /* List 64-bit operations: - name │ op │ m_op │ carry │ use_last_carry │ NOTES - ────────┼──────────┼──────────┼───────┼────────────────┼─────────────────────────────────── - ADD │ 0x02 │ 0x02 │ X │ │ - SUB │ 0x03 │ 0x03 │ X │ │ - LTU │ 0x04 │ 0x04 │ X │ X │ - LT │ 0x05 │ 0x05 │ X │ X │ - LEU │ 0x06 │ 0x06 │ X │ X │ - LE │ 0x07 │ 0x07 │ X │ X │ - EQ │ 0x08 │ 0x08 │ X │ X │ - MINU │ 0x09 │ 0x09 │ X │ │ - MIN │ 0x0a │ 0x0a │ X │ │ - MAXU │ 0x0b │ 0x0b │ X │ │ - MAX │ 0x0c │ 0x0c │ X │ │ - AND │ 0x20 │ 0x20 │ │ │ - OR │ 0x21 │ 0x21 │ │ │ - XOR │ 0x22 │ 0x22 │ │ │ - ────────┼──────────┼──────────┼───────┼────────────────┼─────────────────────────────────── + name │ op │ m_op │ has_initial_carry │ carry │ use_last_carry │ ZisK OP │ Notes │ + ───────────┼──────────┼──────────┼───────────────────┼───────┼────────────────┼─────────┼────────────────────────────────────────────────────┼ + MINU │ 0x02 │ 0x02 │ │ X │ │ X │ │ + MIN │ 0x03 │ 0x03 │ │ X │ │ X │ │ + MAXU │ 0x04 │ 0x04 │ │ X │ │ X │ │ + MAX │ 0x05 │ 0x05 │ │ X │ │ X │ │ + LT_ABS_NP │ 0x06 │ 0x06 │ X │ X │ X │ │ This operation is used by the arithmetic component │ + LT_ABS_PN │ 0x07 │ 0x07 │ X │ X │ X │ │ This operation is used by the arithmetic component │ + LTU │ 0x08 │ 0x08 │ │ X │ X │ X │ │ + LT │ 0x09 │ 0x09 │ │ X │ X │ X │ │ + GT │ 0x0a │ 0x0a │ │ X │ X │ │ This operation is used by the arithmetic component │ + EQ │ 0x0b │ 0x0b │ │ X │ X │ X │ │ + ADD │ 0x0c │ 0x0c │ │ X │ │ X │ │ + SUB │ 0x0d │ 0x0d │ │ X │ │ X │ │ + LEU │ 0x0e │ 0x0e │ │ X │ X │ X │ │ + LE │ 0x0f │ 0x0f │ │ X │ X │ X │ │ + AND │ 0x10 │ 0x10 │ │ │ │ X │ │ + OR │ 0x11 │ 0x11 │ │ │ │ X │ │ + XOR │ 0x12 │ 0x12 │ │ │ │ X │ │ + ───────────┼──────────┼──────────┼───────────────────┼───────┼────────────────┼─────────┼────────────────────────────────────────────────────┼ List 32-bit operations: - name │ op │ m_op │ carry │ use_last_carry │ NOTES - ────────┼──────────┼──────────┼───────┼────────────────┼─────────────────────────────────── - ADD_W │ 0x12 │ 0x02 │ X │ │ - SUB_W │ 0x13 │ 0x03 │ X │ │ - LTU_W │ 0x14 │ 0x04 │ X │ X │ - LT_W │ 0x15 │ 0x05 │ X │ X │ - LEU_W │ 0x16 │ 0x06 │ X │ X │ - LE_W │ 0x17 │ 0x07 │ X │ X │ - EQ_W │ 0x18 │ 0x08 │ X │ X │ - MINU_W │ 0x19 │ 0x09 │ X │ │ - MIN_W │ 0x1a │ 0x0a │ X │ │ - MAXU_W │ 0x1b │ 0x0b │ X │ │ - MAX_W │ 0x1c │ 0x0c │ X │ │ - ────────┼──────────┼──────────┼───────┼────────────────┼─────────────────────────────────── - - Opcodes: - --------------------------------------- - expr op = m_op + 16*mode32 - - mode32 64bits 32bits m_op op - 0/1 ADD ADD_W 0x02 (0x02,0x12) - 0/1 SUB SUB_W 0x03 (0x03,0x13) - 0/1 LTU LTU_W 0x04 (0x04,0x14) - 0/1 LT LT_W 0x05 (0x05,0x15) - 0/1 LEU LEU_W 0x06 (0x06,0x16) - 0/1 LE LE_W 0x07 (0x07,0x17) - 0/1 EQ EQ_W 0x08 (0x08,0x18) - 0/1 MINU MINU_W 0x09 (0x09,0x19) - 0/1 MIN MIN_W 0x0a (0x0a,0x1a) - 0/1 MAXU MAXU_W 0x0b (0x0b,0x1b) - 0/1 MAX MAX_W 0x0c (0x0c,0x1c) - 0/1 AND 0x20 0x20 - 0 OR 0x21 0x21 - 0 XOR 0x22 0x22 + name │ op │ m_op │ has_initial_carry │ carry │ use_last_carry │ ZisK OP │ + ───────────┼──────────┼──────────┼───────────────────┼───────┼────────────────┼─────────│ + MINU_W │ 0x22 │ 0x02 │ │ X │ │ X │ + MIN_W │ 0x23 │ 0x03 │ │ X │ │ X │ + MAXU_W │ 0x24 │ 0x04 │ │ X │ │ X │ + MAX_W │ 0x25 │ 0x05 │ │ X │ │ X │ + LTU_W │ 0x28 │ 0x08 │ │ X │ X │ X │ + LT_W │ 0x29 │ 0x09 │ │ X │ X │ X │ + EQ_W │ 0x2b │ 0x0b │ │ X │ X │ X │ + ADD_W │ 0x2c │ 0x0c │ │ X │ │ X │ + SUB_W │ 0x2d │ 0x0d │ │ X │ │ X │ + LEU_W │ 0x2e │ 0x0e │ │ X │ X │ X │ + LE_W │ 0x2f │ 0x0f │ │ X │ X │ X │ + ───────────┼──────────┼──────────┼───────────────────┼───────┼────────────────┼─────────│ + + Note: op = m_op + 0x20*mode32 */ +const int LT_ABS_NP_OP = 0x06; +const int LT_ABS_PN_OP = 0x07; +const int LTU_OP = 0x08; +const int GT_OP = 0x0a; + airtemplate Binary(const int N = 2**21, const int operation_bus_id) { // Default values const int bits = 64; @@ -70,16 +60,17 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id) { const int input_chunk_bytes = bytes / input_chunks; // Primary columns - col witness m_op; // micro operation code of the binary table (e.g. add) - col witness mode32; // 1 if the operation is 32 bits, 0 otherwise - col witness free_in_a[bytes]; // input1 - col witness free_in_b[bytes]; // input2 - col witness free_in_c[bytes]; // output - col witness carry[bytes]; // bytes chunks carries [0,cout:0],[cin:0,cout:1],...,[cin:bytes-2,cout:bytes-1] + col witness m_op; // micro operation code of the binary table (e.g. add) + col witness mode32; // 1 if the operation is 32 bits, 0 otherwise + col witness free_in_a[bytes]; // input1 + col witness free_in_b[bytes]; // input2 + col witness free_in_c[bytes]; // output + col witness carry[bytes]; // bytes chunks carries [0,cout:0],[cin:0,cout:1],...,[cin:bytes-2,cout:bytes-1] // Secondary columns - col witness use_last_carry; // 1 if the operation uses the last carry as its result - col witness op_is_min_max; // 1 if the operation is any of the MIN/MAX operations + col witness use_last_carry; // 1 if the operation uses the last carry as its result + col witness op_is_min_max; // 1 if the operation is any of the MIN/MAX operations + col witness has_initial_carry; // 1 if the operation has an initial carry const expr mode64 = 1 - mode32; const expr cout32 = carry[half_bytes-1]; @@ -87,7 +78,6 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id) { use_last_carry * (1 - use_last_carry) === 0; op_is_min_max * (1 - op_is_min_max) === 0; - cout32*(1 - cout32) === 0; cout64*(1 - cout64) === 0; // Auxiliary columns (primarily used to optimize lookups, but can be substituted with expressions) @@ -118,7 +108,7 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id) { where last indicates whether the byte is the last one in the operation */ - lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*result_is_a]); + lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], has_initial_carry*INITIAL_CARRY_LT_ABS, free_in_c[0], carry[0] + 2*op_is_min_max + 4*result_is_a]); // More auxiliary columns col witness m_op_or_ext; @@ -178,7 +168,7 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id) { c[0] += use_last_carry * cout; c[input_chunks - 1] -= use_last_carry * cout * factor; - expr op = m_op + 16 * mode32; + expr op = m_op + 0x20 * mode32; col witness multiplicity; col witness main_step; diff --git a/state-machines/binary/pil/binary_extension.pil b/state-machines/binary/pil/binary_extension.pil index 74a411b4..c5d9fc4a 100644 --- a/state-machines/binary/pil/binary_extension.pil +++ b/state-machines/binary/pil/binary_extension.pil @@ -8,15 +8,15 @@ List: ┼────────┼────────┼──────────┼ │ name │ bits │ op │ ┼────────┼────────┼──────────┼ - │ SLL │ 64 │ 0x0d │ - │ SRL │ 64 │ 0x0e │ - │ SRA │ 64 │ 0x0f │ - │ SLL_W │ 32 │ 0x1d │ - │ SRL_W │ 32 │ 0x1e │ - │ SRA_W │ 32 │ 0x1f │ - │ SE_B │ 32 │ 0x23 │ - │ SE_H │ 32 │ 0x24 │ - │ SE_W │ 32 │ 0x25 │ + │ SLL │ 64 │ 0x31 │ + │ SRL │ 64 │ 0x32 │ + │ SRA │ 64 │ 0x33 │ + │ SLL_W │ 32 │ 0x34 │ + │ SRL_W │ 32 │ 0x35 │ + │ SRA_W │ 32 │ 0x36 │ + │ SE_B │ 32 │ 0x37 │ + │ SE_H │ 32 │ 0x38 │ + │ SE_W │ 32 │ 0x39 │ ┼────────┼────────┼──────────┼ Examples: diff --git a/state-machines/binary/pil/binary_extension_table.pil b/state-machines/binary/pil/binary_extension_table.pil index d346ff97..35e0ad35 100644 --- a/state-machines/binary/pil/binary_extension_table.pil +++ b/state-machines/binary/pil/binary_extension_table.pil @@ -3,15 +3,15 @@ require "std_lookup.pil" // Operations Table: // Running Total -// SLL (OP:0x0d) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^19 -// SRL (OP:0x0e) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 -// SRA (OP:0x0f) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 + 2^19 -// SLL_W (OP:0x1d) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 -// SRL_W (OP:0x1e) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^19 -// SRA_W (OP:0x1f) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^20 -// SE_B (OP:0x23) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 + 2^11 -// SE_H (OP:0x24) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 2^12 -// SE_W (OP:0x25) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 2^12 + 2^11 => 2^22 +// SLL (OP:0x31) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^19 +// SRL (OP:0x32) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 +// SRA (OP:0x33) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 + 2^19 +// SLL_W (OP:0x34) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 +// SRL_W (OP:0x35) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^19 +// SRA_W (OP:0x36) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^20 +// SE_B (OP:0x37) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 + 2^11 +// SE_H (OP:0x38) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 2^12 +// SE_W (OP:0x39) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 2^12 + 2^11 => 2^22 const int BINARY_EXTENSION_TABLE_ID = 124; @@ -49,18 +49,25 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = #pragma timer tt start #pragma timer t1 start - col fixed A = [0..255]...; // Input A (8 bits) - col fixed OFFSET = [0:P2_8..(bytes-1):P2_8]...; // Offset (3 bits) + // Input A (8 bits) + col fixed A = [0..255]...; - col fixed B = [[0:P2_11..255:P2_11]:6, // Input B (8 bits) - 0:(P2_11*3)]...; + // Offset (3 bits) + col fixed OFFSET = [0:P2_8..(bytes-1):P2_8]...; - col fixed OP = [0x0d:P2_19, 0x0e:P2_19, 0x0f:P2_19, // SLL, SRL, SRA - 0x1d:P2_19, 0x1e:P2_19, 0x1f:P2_19, // SLL_W, SRL_W, SRA_W - 0x23:P2_11, 0x24:P2_11, 0x25:P2_11]...; // SE_B, SE_H, SE_W + // Input B (8 bits) + col fixed B = [[0:P2_11..255:P2_11]:6, // SLL, SRL, SRA, SLL_W, SRL_W, SRA_W + 0:(P2_11*3)]...; // SE_B, SE_H, SE_W - col fixed OP_IS_SHIFT = [1:(P2_19*6), 0:(P2_11*3)]...; + // Operation is shift (fixed values) + col fixed OP_IS_SHIFT = [1:(P2_19*6), // SLL, SRL, SRA, SLL_W, SRL_W, SRA_W + 0:(P2_11*3)]...; // SE_B, SE_H, SE_W + + // Operation opcode (fixed values) + col fixed OP = [0x31:P2_19, 0x32:P2_19, 0x33:P2_19, // SLL, SRL, SRA + 0x34:P2_19, 0x35:P2_19, 0x36:P2_19, // SLL_W, SRL_W, SRA_W + 0x37:P2_11, 0x38:P2_11, 0x39:P2_11]...; // SE_B, SE_H, SE_W #pragma timer t1 end #pragma timer t2 start @@ -76,13 +83,13 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = int _out = 0; const int _a = a << (8*offset); switch (op) { - case 0x0d: // SLL + case 0x31: // SLL _out = _a << (b & LS_6_BITS); - case 0x0e: // SRL + case 0x32: // SRL _out = _a >> (b & LS_6_BITS); - case 0x0f: { // SRA + case 0x33: { // SRA const int _b = b & LS_6_BITS; _out = _a >> _b; if (offset == 7) { @@ -93,7 +100,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } } } - case 0x1d: // SLL_W + case 0x34: // SLL_W if (offset >= 4) { // last most significant bytes are ignored because it's 32-bit operation _out = 0; @@ -104,7 +111,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } } - case 0x1e: // SRL_W + case 0x35: // SRL_W if (offset >= 4) { // last most significant bytes are ignored because it's 32-bit operation _out = 0; @@ -115,7 +122,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } } - case 0x1f: // SRA_W + case 0x36: // SRA_W if (offset >= 4) { // last most significant bytes are ignored because it's 32-bit operation _out = 0; @@ -131,7 +138,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } } - case 0x23: // SE_B + case 0x37: // SE_B if (offset == 0) { // the most significant bit of first byte determines the sign extend _out = (a & SIGN_BYTE) ? a | SE_MASK_8 : a @@ -140,7 +147,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = _out = 0; } - case 0x24: // SE_H + case 0x38: // SE_H if (offset == 0) { // fist byte not define the sign extend, but participate of result _out = a; @@ -152,7 +159,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = _out = 0; } - case 0x25: // SE_W + case 0x39: // SE_W if (offset <= 3) { _out = _a; if (offset == 3) { diff --git a/state-machines/binary/pil/binary_table.pil b/state-machines/binary/pil/binary_table.pil index 316ef9a0..30b26e44 100644 --- a/state-machines/binary/pil/binary_table.pil +++ b/state-machines/binary/pil/binary_table.pil @@ -3,30 +3,37 @@ require "std_lookup.pil" // PIL Binary Operations Table used by Binary // Running Total -// MINU/MINU_W (OP:0x09) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^19 -// MIN/MIN_W (OP:0x0a) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20 -// MAXU/MAXU_W (OP:0x0b) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20 + 2^19 -// MAX/MAX_W (OP:0x0c) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^21 -// LTU/LTU_W (OP:0x04) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^18 -// LT/LT_W (OP:0x05) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^19 -// EQ/EQ_W (OP:0x08) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^19 + 2^18 -// ADD/ADD_W (OP:0x02) ** 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 -// SUB/SUB_W (OP:0x03) ** 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^18 -// LEU/LEU_W (OP:0x06) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^18 + 2^17 -// LE/LE_W (OP:0x07) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19 -// AND/AND_W (OP:0x20) 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19 + 2^17 -// OR/OR_W (OP:0x21) 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19 + 2^18 -// XOR/XOR_W (OP:0x22) 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19 + 2^18 + 2^17 -// EXT_32 (OP:0x23) 2^8 (A) x 2^1 (CIN) x 2^2 (FLAGS) = 2^16 | 2^21 + 2^20 + 2^19 + 2^18 + 2^17 + 2^11 => < 2^22 +// MINU/MINU_W (OP:0x02) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^19 +// MIN/MIN_W (OP:0x03) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20 +// MAXU/MAXU_W (OP:0x04) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20 + 2^19 +// MAX/MAX_W (OP:0x05) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^21 +// LT_ABS_NP (OP:0x06) * 2^16 (AxB) x 2^1 (LAST) x 2^2 (CIN) = 2^19 | 2^21 + 2^19 +// LT_ABS_PN (OP:0x07) * 2^16 (AxB) x 2^1 (LAST) x 2^2 (CIN) = 2^19 | 2^21 + 2^20 +// LTU/LTU_W (OP:0x08) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^18 +// LT/LT_W (OP:0x09) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^19 +// GT (OP:0x0a) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^19 + 2^18 +// EQ/EQ_W (OP:0x0b) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^22 +// ADD/ADD_W (OP:0x0c) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^22 + 2^18 +// SUB/SUB_W (OP:0x0d) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^22 + 2^19 +// LEU/LEU_W (OP:0x0e) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^19 + 2^17 +// LE/LE_W (OP:0x0f) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^19 + 2^18 +// AND/AND_W (OP:0x10) ** 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^19 + 2^18 + 2^17 +// OR/OR_W (OP:0x11) ** 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^20 +// XOR/XOR_W (OP:0x12) ** 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^20 + 2^17 +// EXT_32 (OP:0x13) *** 2^8 (A) x 2^1 (CIN) x 2^2 (FLAGS) = 2^11 | 2^22 + 2^20 + 2^17 + 2^11 => < 2^23 // -------------------------------------------------------------------------------------------------------------------------- -// (*) Use carry -// (**) Do not use last indicator, but it is used for simplicity of the lookup -// Note: EXT_32 is the only unary operation +// (*) Uses the carry of the last byte of the result (use_last_carry) +// (**) The op do not use LAST, but the binary does so we need to consider it +// (***) The op do not use CIN, but the binary does so we need to consider it +// Note: EXT_32 is the only unary operation which is not a ZisK OP but it is used to prove the rest -const int EXT_32_OP = 0x23; const int BINARY_TABLE_ID = 125; -airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) { +const int EXT_32_OP = 0x13; + +const int INITIAL_CARRY_LT_ABS = 0x02; + +airtemplate BinaryTable(const int N = 2**23, const int disable_fixed = 0) { #pragma memory m1 start col witness multiplicity; @@ -40,39 +47,49 @@ airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) { return; } - if (N < 2**22) { - error(`N must be at least 2^22, but N=${N} was provided`); + if (N < 2**23) { + error(`N must be at least 2^23, but N=${N} was provided`); } #pragma timer tt start #pragma timer t1 start - col fixed A = [0..255]...; // Input A (8 bits) + // Input A (8 bits) + col fixed A = [0..255]...; - col fixed B = [[0:P2_8..255:P2_8]:62,0:P2_11]...; // Input B (8 bits) + // Input B (<=8 bits) + col fixed B = [[0:P2_8..255:P2_8]:82, // 82 = 4*8 + 2*8 + 4*4 + 2*4 + 2*2 + 3*2 + 0:P2_11]...; // B is 0 for EXT_32 - col fixed LAST = [[0:P2_16, 1:P2_16]:(4*4), // Indicator of the last byte (1 bit) - [0:P2_16, 1:P2_16]:(5*2), - [0:P2_16, 1:P2_16]:5, + // Indicator of the last byte (<=1 bit) + col fixed LAST = [[0:P2_16, 1:P2_16]:(4*4), // MINU,MIN,MAXU,MAX + [0:P2_16, 1:P2_16]:(2*4), // LT_ABS_NP,LT_ABS_PN + [0:P2_16, 1:P2_16]:(4*2), // LTU,LT,GT,EQ + [0:P2_16, 1:P2_16]:(2*2), // ADD,SUB + [0:P2_16, 1:P2_16]:2, // LEU,LE + [0:P2_16, 1:P2_16]:3, // AND,OR,XOR 0:P2_11]...; - col fixed CIN = [[0:P2_17, 1:P2_17]:(4*2), // Input carry (1 bit) - [0:P2_17, 1:P2_17]:5, - 0:(P2_17*5), - [0:P2_8, 1:P2_8]:4]...; - - col fixed OP = [0x09:P2_19, 0x0a:P2_19, 0x0b:P2_19, 0x0c:P2_19, // MINU,MIN,MAXU,MAX - 0x04:P2_18, 0x05:P2_18, 0x08:P2_18, // LTU,LT,EQ - 0x02:P2_18, 0x03:P2_18, // ADD,SUB - 0x06:P2_17, 0x07:P2_17, // LEU,LE - 0x20:P2_17, 0x21:P2_17, 0x22:P2_17, // AND,OR,XOR - EXT_32_OP:P2_11]...; // EXT_32 - - // NOTE: MINU/MINU_W, MIN/MIN_W, MAXU/MAXU_W, MAX/MAX_W has double size because - // the result_is_a is 0 in the first half and 1 in the second half. - - const int TABLE_SIZE = P2_19 * 4 + P2_18 * 5 + P2_17 * 6; - const int TABLE_BASE_EXT32 = P2_16 * 62; + // Input carry (<=2 bits) + col fixed CIN = [[0:P2_17, 1:P2_17]:(4*2), // MINU,MIN,MAXU,MAX + [0:P2_17..3:P2_17]:2, // LT_ABS_NP,LT_ABS_PN + [0:P2_17, 1:P2_17]:4, // LTU,LT,GT,EQ + [0:P2_17, 1:P2_17]:2, // ADD,SUB + 0:(P2_17*2), // LEU,LE + 0:(P2_17*3), // AND,OR,XOR + [0:P2_8, 1:P2_8]:4]...; // EXT_32 + + // Operation opcode (fixed values) + col fixed OP = [0x02:P2_19, 0x03:P2_19, 0x04:P2_19, 0x05:P2_19, // MINU,MIN,MAXU,MAX + 0x06:P2_19, 0x07:P2_19, // LT_ABS_NP,LT_ABS_PN + 0x08:P2_18, 0x09:P2_18, 0x0a:P2_18, 0x0b:P2_18, // LTU,LT,GT,EQ + 0x0c:P2_18, 0x0d:P2_18, // ADD,SUB + 0x0e:P2_17, 0x0f:P2_17, // LEU,LE + 0x10:P2_17, 0x11:P2_17, 0x12:P2_17, // AND,OR,XOR + 0x13:P2_11]...; // EXT_32 + + const int TABLE_SIZE = P2_19 * 6 + P2_18 * 6 + P2_17 * 5 + P2_11; + const int TABLE_BASE_EXT32 = P2_16 * 82; #pragma timer t1 end #pragma timer t2 start @@ -90,16 +107,102 @@ airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) { int index = i % TABLE_SIZE; int result_is_a = index < P2_21 ? ((index >> 18) & 0x01) : 0; switch (op) { - case 0x02: // ADD,ADD_W - c = (cin + a + b) & 0xFF; - cout = plast ? 0 : (cin + a + b) >> 8; + case 0x02,0x03: // MINU,MINU_W,MIN,MIN_W + // cout = 1 indicates that a is lower than b + if (a < b) { + cout = 1; + } else if (a == b) { + cout = cin; + } - case 0x03: // SUB,SUB_W - sign = (a - cin) >= b ? 0 : 1; - c = 256 * sign + a - cin - b; - cout = plast ? 0 : sign; + if (result_is_a) { + c = a; + } else { + c = b; + } + + if (op == 0x03 && plast) { + if ((a & 0x80) != (b & 0x80)) { + cout = (a & 0x80) ? 1 : 0; + } + } + + op_is_min_max = 1; + + case 0x04,0x05: // MAXU,MAXU_W,MAX,MAX_W + // cout = 1 indicates that a is greater than b + if (a > b) { + cout = 1; + } else if (a == b) { + cout = cin; + } + + if (result_is_a) { + c = a; + } else { + c = b; + } + + if (op == 0x05 && plast) { + if ((a & 0x80) != (b & 0x80)) { + cout = (a & 0x80) ? 0 : 1; + } + } + op_is_min_max = 1; + + case 0x06: // LT_ABS_NP + // Both necessary carries are encoded by cin in binary as + // cin = 0bYX, + // where X is the carry of the LT operation and Y is + // the carry of the operation a ^ 0xFF + _cop + + // Decode the carries + const int _clt = cin & 0x01; + const int _cop = (cin & 0x02) >> 1; + + const int _a = (a ^ 0xFF) + _cop; // _cop should be 1 at the first byte and _a >> 8 at the rest + const int _b = b; + + if ((_a & 0xFF) < _b) { + cout = 1; + c = plast; + } else if ((_a & 0xFF) == _b) { + cout = _clt; + c = plast * _clt; + } + + // Encode the result carries + cout += 2*(_a >> 8); + + use_last_carry = plast; - case 0x04,0x05: // LTU,LTU_W,LT,LT_W + case 0x07: // LT_ABS_PN + // Both necessary carries are encoded by cin in binary as + // cin = 0bYX, + // where X is the carry of the LT operation and Y is + // the carry of the operation b ^ 0xFF + _cop + + // Decode the carries + const int _clt = cin & 0x01; + const int _cop = (cin & 0x02) >> 1; + + const int _a = a; + const int _b = (b ^ 0xFF) + _cop; // _cop should be 1 at the first byte and _b >> 8 at the rest + + if (_a < (_b & 0xFF)) { + cout = 1; + c = plast; + } else if (_a == (_b & 0xFF)) { + cout = _clt; + c = plast * _clt; + } + + // Encode the result carries + cout += 2*(_b >> 8); + + use_last_carry = plast; + + case 0x08,0x09: // LTU,LTU_W,LT,LT_W if (a < b) { cout = 1; c = plast; @@ -109,86 +212,69 @@ airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) { } // If the chunk is signed, then the result is the sign of a - if (op == 0x05 && plast && (a & 0x80) != (b & 0x80)) { + if (op == 0x09 && plast && (a & 0x80) != (b & 0x80)) { c = (a & 0x80) ? 1 : 0; cout = c; } use_last_carry = plast; - case 0x06,0x07: // LEU,LEU_W,LE,LE_W - if (a <= b) { + case 0x0a: // GT + if (a > b) { cout = 1; c = plast; + } else if (a == b) { + cout = cin; + c = plast * cin; } - if (op == 0x07 && plast && (a & 0x80) != (b & 0x80)) { - c = (a & 0x80) ? 1 : 0; + // The result is the sign of b + if (plast && (a & 0x80) != (b & 0x80)) { + c = (b & 0x80) ? 1 : 0; cout = c; } use_last_carry = plast; - case 0x08: // EQ,EQ_W + case 0x0b: // EQ,EQ_W if (a == b && !cin) c = plast; else cout = 1; if (plast) cout = 1 - cout; use_last_carry = plast; - case 0x09,0x0a: // MINU,MINU_W,MIN,MIN_W - // cout = 1 indicates that a is lower than b - if (a < b) { - cout = 1; - } else if (a == b) { - cout = cin; - } - - if (result_is_a) { - c = a; - } else { - c = b; - } - - if (op == 0x0a && plast) { - if ((a & 0x80) != (b & 0x80)) { - cout = (a & 0x80) ? 1 : 0; - } - } + case 0x0c: // ADD,ADD_W + c = (cin + a + b) & 0xFF; + cout = plast ? 0 : (cin + a + b) >> 8; - op_is_min_max = 1; + case 0x0d: // SUB,SUB_W + sign = (a - cin) >= b ? 0 : 1; + c = 256 * sign + a - cin - b; + cout = plast ? 0 : sign; - case 0x0b,0x0c: // MAXU,MAXU_W,MAX,MAX_W - // cout = 1 indicates that a is greater than b - if (a > b) { + case 0x0e,0x0f: // LEU,LEU_W,LE,LE_W + if (a <= b) { cout = 1; - } else if (a == b) { - cout = cin; + c = plast; } - if (result_is_a) { - c = a; - } else { - c = b; + if (op == 0x0f && plast && (a & 0x80) != (b & 0x80)) { + c = (a & 0x80) ? 1 : 0; + cout = c; } - if (op == 0x0c && plast) { - if ((a & 0x80) != (b & 0x80)) { - cout = (a & 0x80) ? 0 : 1; - } - } - op_is_min_max = 1; + use_last_carry = plast; - case 0x20: // AND + case 0x10: // AND c = a & b; - case 0x21: // OR + case 0x11: // OR c = a | b; - case 0x22: // XOR + case 0x12: // XOR c = a ^ b; - case 0x23: // EXT_32 + case 0x13: // EXT_32 c = (a & 0x80) ? 0xFF : 0x00; const int index_offset = (index - TABLE_BASE_EXT32) >> 9; op_is_min_max = index_offset & 0x01; diff --git a/state-machines/binary/src/binary.rs b/state-machines/binary/src/binary.rs index 01311cd6..dc995864 100644 --- a/state-machines/binary/src/binary.rs +++ b/state-machines/binary/src/binary.rs @@ -7,8 +7,6 @@ use crate::{BinaryBasicSM, BinaryBasicTableSM, BinaryExtensionSM, BinaryExtensio use p3_field::PrimeField; use pil_std_lib::Std; use proofman::{WitnessComponent, WitnessManager}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; use zisk_core::ZiskRequiredOperation; use zisk_pil::{ BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, BINARY_EXTENSION_TABLE_AIR_IDS, BINARY_TABLE_AIR_IDS, @@ -78,13 +76,8 @@ impl BinarySM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - /* as Provable>::prove( - self, - &[], - true, - scope, - );*/ - //self.threads_controller.wait_for_threads(); + // If there are remaining binary inputs, prove them + self.prove(&[], true); self.binary_basic_sm.unregister_predecessor(); self.binary_extension_sm.unregister_predecessor(); @@ -103,19 +96,15 @@ impl BinarySM { self.binary_extension_sm.prove_instance(operations, prover_buffer); } } -} - -impl WitnessComponent for BinarySM {} -impl Provable for BinarySM { - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { + pub fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool) { + // Split the operations into basic and extended operations let mut _inputs_basic = Vec::new(); let mut _inputs_extension = Vec::new(); let basic_operations = BinaryBasicSM::::operations(); let extension_operations = BinaryExtensionSM::::operations(); - // TODO Split the operations into basic and extended operations in parallel for operation in operations { if basic_operations.contains(&operation.opcode) { _inputs_basic.push(operation.clone()); @@ -126,31 +115,34 @@ impl Provable for BinarySM { } } - let mut inputs_basic = self.inputs_basic.lock().unwrap(); - inputs_basic.extend(_inputs_basic); - - while inputs_basic.len() >= PROVE_CHUNK_SIZE || (drain && !inputs_basic.is_empty()) { - let num_drained_basic = std::cmp::min(PROVE_CHUNK_SIZE, inputs_basic.len()); - let drained_inputs_basic = inputs_basic.drain(..num_drained_basic).collect::>(); + // Accumulate the basic operations, proving them once there are enough + if let Ok(mut inputs_basic) = self.inputs_basic.lock() { + inputs_basic.extend(_inputs_basic); - let binary_basic_sm_cloned = self.binary_basic_sm.clone(); + while inputs_basic.len() >= PROVE_CHUNK_SIZE || (drain && !inputs_basic.is_empty()) { + let num_drained_basic = std::cmp::min(PROVE_CHUNK_SIZE, inputs_basic.len()); + let drained_inputs_basic = + inputs_basic.drain(..num_drained_basic).collect::>(); - binary_basic_sm_cloned.prove(&drained_inputs_basic, false, scope); + self.binary_basic_sm.prove(&drained_inputs_basic, false); + } } - drop(inputs_basic); - let mut inputs_extension = self.inputs_extension.lock().unwrap(); - inputs_extension.extend(_inputs_extension); + // Accumulate the extension operations, proving them once there are enough + if let Ok(mut inputs_extension) = self.inputs_extension.lock() { + inputs_extension.extend(_inputs_extension); - while inputs_extension.len() >= PROVE_CHUNK_SIZE || (drain && !inputs_extension.is_empty()) - { - let num_drained_extension = std::cmp::min(PROVE_CHUNK_SIZE, inputs_extension.len()); - let drained_inputs_extension = - inputs_extension.drain(..num_drained_extension).collect::>(); - let binary_extension_sm_cloned = self.binary_extension_sm.clone(); + while inputs_extension.len() >= PROVE_CHUNK_SIZE || + (drain && !inputs_extension.is_empty()) + { + let num_drained_extension = std::cmp::min(PROVE_CHUNK_SIZE, inputs_extension.len()); + let drained_inputs_extension = + inputs_extension.drain(..num_drained_extension).collect::>(); - binary_extension_sm_cloned.prove(&drained_inputs_extension, false, scope); + self.binary_extension_sm.prove(&drained_inputs_extension, false); + } } - drop(inputs_extension); } } + +impl WitnessComponent for BinarySM {} diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index e0c8f206..dccedaf5 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -8,16 +8,47 @@ use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; use std::cmp::Ordering as CmpOrdering; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::*; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; +// 64 bits opcodes +const MINU_OP: u8 = ZiskOp::Minu.code(); +const MIN_OP: u8 = ZiskOp::Min.code(); +const MAXU_OP: u8 = ZiskOp::Maxu.code(); +const MAX_OP: u8 = ZiskOp::Max.code(); +pub const LT_ABS_NP_OP: u8 = 0x06; +pub const LT_ABS_PN_OP: u8 = 0x07; +pub const LTU_OP: u8 = ZiskOp::Ltu.code(); +const LT_OP: u8 = ZiskOp::Lt.code(); +pub const GT_OP: u8 = 0x0a; +const EQ_OP: u8 = ZiskOp::Eq.code(); +const ADD_OP: u8 = ZiskOp::Add.code(); +const SUB_OP: u8 = ZiskOp::Sub.code(); +const LEU_OP: u8 = ZiskOp::Leu.code(); +const LE_OP: u8 = ZiskOp::Le.code(); +const AND_OP: u8 = ZiskOp::And.code(); +const OR_OP: u8 = ZiskOp::Or.code(); +const XOR_OP: u8 = ZiskOp::Xor.code(); + +// 32 bits opcodes +const MINUW_OP: u8 = ZiskOp::MinuW.code(); +const MINW_OP: u8 = ZiskOp::MinW.code(); +const MAXUW_OP: u8 = ZiskOp::MaxuW.code(); +const MAXW_OP: u8 = ZiskOp::MaxW.code(); +const LTUW_OP: u8 = ZiskOp::LtuW.code(); +const LTW_OP: u8 = ZiskOp::LtW.code(); +const EQW_OP: u8 = ZiskOp::EqW.code(); +const ADDW_OP: u8 = ZiskOp::AddW.code(); +const SUBW_OP: u8 = ZiskOp::SubW.code(); +const LEUW_OP: u8 = ZiskOp::LeuW.code(); +const LEW_OP: u8 = ZiskOp::LeW.code(); + const BYTES: usize = 8; const HALF_BYTES: usize = BYTES / 2; +const MASK_U64: u64 = 0xFFFF_FFFF_FFFF_FFFF; pub struct BinaryBasicSM { wcm: Arc>, @@ -67,12 +98,8 @@ impl BinaryBasicSM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - /* as Provable>::prove( - self, - &[], - true, - scope, - );*/ + // If there are remaining inputs, prove them + self.prove(&[], true); self.binary_basic_table_sm.unregister_predecessor(); } @@ -80,66 +107,96 @@ impl BinaryBasicSM { pub fn operations() -> Vec { vec![ - // 64 bits opcodes - ZiskOp::Add.code(), - ZiskOp::Sub.code(), - ZiskOp::Ltu.code(), - ZiskOp::Lt.code(), - ZiskOp::Leu.code(), - ZiskOp::Le.code(), - ZiskOp::Eq.code(), - ZiskOp::Minu.code(), - ZiskOp::Min.code(), - ZiskOp::Maxu.code(), - ZiskOp::Max.code(), - ZiskOp::And.code(), - ZiskOp::Or.code(), - ZiskOp::Xor.code(), - // 32 bits opcodes - ZiskOp::AddW.code(), - ZiskOp::SubW.code(), - ZiskOp::LtuW.code(), - ZiskOp::LtW.code(), - ZiskOp::LeuW.code(), - ZiskOp::LeW.code(), - ZiskOp::EqW.code(), - ZiskOp::MinuW.code(), - ZiskOp::MinW.code(), - ZiskOp::MaxuW.code(), - ZiskOp::MaxW.code(), + MINU_OP, + MIN_OP, + MAXU_OP, + MAX_OP, + LT_ABS_NP_OP, + LT_ABS_PN_OP, + LTU_OP, + LT_OP, + GT_OP, + EQ_OP, + ADD_OP, + SUB_OP, + LEU_OP, + LE_OP, + AND_OP, + OR_OP, + XOR_OP, + MINUW_OP, + MINW_OP, + MAXUW_OP, + MAXW_OP, + LTUW_OP, + LTW_OP, + EQW_OP, + ADDW_OP, + SUBW_OP, + LEUW_OP, + LEW_OP, ] } - fn opcode_is_32_bits(opcode: ZiskOp) -> bool { - match opcode { - ZiskOp::Add | - ZiskOp::Sub | - ZiskOp::Ltu | - ZiskOp::Lt | - ZiskOp::Leu | - ZiskOp::Le | - ZiskOp::Eq | - ZiskOp::Minu | - ZiskOp::Min | - ZiskOp::Maxu | - ZiskOp::Max | - ZiskOp::And | - ZiskOp::Or | - ZiskOp::Xor => false, - - ZiskOp::AddW | - ZiskOp::SubW | - ZiskOp::LtuW | - ZiskOp::LtW | - ZiskOp::LeuW | - ZiskOp::LeW | - ZiskOp::EqW | - ZiskOp::MinuW | - ZiskOp::MinW | - ZiskOp::MaxuW | - ZiskOp::MaxW => true, - - _ => panic!("Binary basic opcode_is_32_bits() got invalid opcode={:?}", opcode), + fn opcode_is_32_bits(opcode: u8) -> bool { + const OPCODES_32_BITS: [u8; 11] = [ + MINUW_OP, MINW_OP, MAXUW_OP, MAXW_OP, LTUW_OP, LTW_OP, EQW_OP, ADDW_OP, SUBW_OP, + LEUW_OP, LEW_OP, + ]; + + OPCODES_32_BITS.contains(&opcode) + } + + fn lt_abs_np_execute(a: u64, b: u64) -> (u64, bool) { + let a_pos = (a ^ MASK_U64).wrapping_add(1); + if a_pos < b { + (1, true) + } else { + (0, false) + } + } + + fn lt_abs_pn_execute(a: u64, b: u64) -> (u64, bool) { + let b_pos = (b ^ MASK_U64).wrapping_add(1); + if a < b_pos { + (1, true) + } else { + (0, false) + } + } + + fn gt_execute(a: u64, b: u64) -> (u64, bool) { + if (a as i64) > (b as i64) { + (1, true) + } else { + (0, false) + } + } + + fn execute(opcode: u8, a: u64, b: u64) -> (u64, bool) { + let is_zisk_op = ZiskOp::try_from_code(opcode).is_ok(); + if is_zisk_op { + ZiskOp::execute(opcode, a, b) + } else { + match opcode { + LT_ABS_NP_OP => Self::lt_abs_np_execute(a, b), + LT_ABS_PN_OP => Self::lt_abs_pn_execute(a, b), + GT_OP => Self::gt_execute(a, b), + _ => panic!("BinaryBasicSM::execute() got invalid opcode={:?}", opcode), + } + } + } + + fn get_inital_carry(opcode: u8) -> u64 { + let is_zisk_op = ZiskOp::try_from_code(opcode).is_ok(); + if is_zisk_op { + 0 + } else { + match opcode { + LT_ABS_NP_OP | LT_ABS_PN_OP => 2, + GT_OP => 0, + _ => panic!("BinaryBasicSM::execute() got invalid opcode={:?}", opcode), + } } } @@ -152,13 +209,10 @@ impl BinaryBasicSM { let mut row: BinaryRow = Default::default(); // Execute the opcode - let c: u64; - let flag: bool; - (c, flag) = ZiskOp::execute(operation.opcode, operation.a, operation.b); - let _flag = flag; + let opcode = operation.opcode; + let (c, _) = Self::execute(opcode, operation.a, operation.b); // Set mode32 - let opcode = ZiskOp::try_from_code(operation.opcode).expect("Invalid ZiskOp opcode"); let mode32 = Self::opcode_is_32_bits(opcode); row.mode32 = F::from_bool(mode32); let mode64 = F::from_bool(!mode32); @@ -189,40 +243,66 @@ impl BinaryBasicSM { // Set use last carry and carry[], based on operation let mut cout: u64; - let mut cin: u64 = 0; + let mut cin: u64 = Self::get_inital_carry(opcode); let plast: [u64; 8] = if mode32 { [0, 0, 0, 1, 0, 0, 0, 0] } else { [0, 0, 0, 0, 0, 0, 0, 1] }; + // Calculate the byte that sets the carry let carry_byte = if mode32 { 3 } else { 7 }; let binary_basic_table_op: BinaryBasicTableOp; - let op = ZiskOp::try_from_code(operation.opcode).unwrap(); - match op { - ZiskOp::Add | ZiskOp::AddW => { + match opcode { + MINU_OP | MINUW_OP | MIN_OP | MINW_OP => { // Set opcode is min or max - row.op_is_min_max = F::zero(); + row.op_is_min_max = F::one(); + + let result_is_a: u64 = + if (operation.a == operation.b) || (operation.b == c_filtered) { 0 } else { 1 }; // Set the binary basic table opcode - binary_basic_table_op = BinaryBasicTableOp::Add; + binary_basic_table_op = if (opcode == MINU_OP) || (opcode == MINUW_OP) { + BinaryBasicTableOp::Minu + } else { + BinaryBasicTableOp::Min + }; // Set use last carry to zero row.use_last_carry = F::zero(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // Apply the logic to every byte for i in 0..8 { // Calculate carry let previous_cin = cin; - let result = cin + a_bytes[i] as u64 + b_bytes[i] as u64; - cout = result >> 8; - cin = if i == carry_byte { 0 } else { cout }; + match a_bytes[i].cmp(&b_bytes[i]) { + CmpOrdering::Greater => { + cout = 0; + } + CmpOrdering::Less => { + cout = 1; + } + CmpOrdering::Equal => { + cout = cin; + } + } + + // If the chunk is signed, then the result is the sign of a + if (binary_basic_table_op == BinaryBasicTableOp::Min) && + (plast[i] == 1) && + (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + { + cout = if (a_bytes[i] & 0x80) != 0 { 1 } else { 0 }; + } + if mode32 && (i >= 4) { + cout = 0; + } + cin = cout; row.carry[i] = F::from_canonical_u64(cin); //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; - let flags = cin; - - // Set a and b bytes - let a_byte = if mode32 && (i >= 4) { c_bytes[3] } else { a_bytes[i] }; - let b_byte = if mode32 && (i >= 4) { 0 } else { b_bytes[i] }; + let flags = cout + 2 + 4 * result_is_a; // Store the required in the vector let row = BinaryBasicTableSM::::calculate_table_row( @@ -231,41 +311,66 @@ impl BinaryBasicSM { } else { binary_basic_table_op }, - a_byte as u64, - b_byte as u64, + if mode32 && (i >= 4) { c_bytes[3] as u64 } else { a_bytes[i] as u64 }, + b_bytes[i] as u64, previous_cin, plast[i], - c_bytes[i] as u64, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Sub | ZiskOp::SubW => { + MAXU_OP | MAXUW_OP | MAX_OP | MAXW_OP => { // Set opcode is min or max - row.op_is_min_max = F::zero(); + row.op_is_min_max = F::one(); + + let result_is_a: u64 = + if (operation.a == operation.b) || (operation.b == c_filtered) { 0 } else { 1 }; // Set the binary basic table opcode - binary_basic_table_op = BinaryBasicTableOp::Sub; + binary_basic_table_op = if (opcode == MAXU_OP) || (opcode == MAXUW_OP) { + BinaryBasicTableOp::Maxu + } else { + BinaryBasicTableOp::Max + }; // Set use last carry to zero row.use_last_carry = F::zero(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // Apply the logic to every byte for i in 0..8 { // Calculate carry let previous_cin = cin; - cout = if a_bytes[i] as u64 >= (b_bytes[i] as u64 + cin) { 0 } else { 1 }; - cin = if i == carry_byte { 0 } else { cout }; + match a_bytes[i].cmp(&b_bytes[i]) { + CmpOrdering::Greater => { + cout = 1; + } + CmpOrdering::Less => { + cout = 0; + } + CmpOrdering::Equal => { + cout = cin; + } + } + + // If the chunk is signed, then the result is the sign of a + if (binary_basic_table_op == BinaryBasicTableOp::Max) && + (plast[i] == 1) && + (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + { + cout = if (a_bytes[i] & 0x80) != 0 { 0 } else { 1 }; + } + if mode32 && (i >= 4) { + cout = 0; + } + cin = cout; row.carry[i] = F::from_canonical_u64(cin); //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; - let flags = cin; - - // Set a and b bytes - let a_byte = if mode32 && (i >= 4) { c_bytes[3] } else { a_bytes[i] }; - let b_byte = if mode32 && (i >= 4) { 0 } else { b_bytes[i] }; + let flags = cout + 2 + 4 * result_is_a; // Store the required in the vector let row = BinaryBasicTableSM::::calculate_table_row( @@ -274,23 +379,133 @@ impl BinaryBasicSM { } else { binary_basic_table_op }, - a_byte as u64, - b_byte as u64, + if mode32 && (i >= 4) { c_bytes[3] as u64 } else { a_bytes[i] as u64 }, + b_bytes[i] as u64, previous_cin, plast[i], - c_bytes[i] as u64, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Ltu | ZiskOp::LtuW | ZiskOp::Lt | ZiskOp::LtW => { + LT_ABS_NP_OP => { // Set opcode is min or max row.op_is_min_max = F::zero(); // Set the binary basic table opcode - binary_basic_table_op = if (op == ZiskOp::Ltu) || (op == ZiskOp::LtuW) { + binary_basic_table_op = BinaryBasicTableOp::LtAbsNP; + + // Set use last carry + row.use_last_carry = F::one(); + + // Set has initial carry + row.has_initial_carry = F::one(); + + // Apply the logic to every byte + for i in 0..8 { + let _clt = cin & 0x01; + let _cop = (cin & 0x02) >> 1; + + let _a = (a_bytes[i] as u64 ^ 0xFF) + _cop; + let _b = b_bytes[i] as u64; + + // Calculate the output carry + let previous_cin = cin; + match (_a & 0xFF).cmp(&_b) { + CmpOrdering::Less => { + cout = 1; + } + CmpOrdering::Equal => { + cout = _clt; + } + CmpOrdering::Greater => { + cout = 0; + } + } + + cout += 2 * (_a >> 8); + row.carry[i] = F::from_canonical_u64(cout); + + // Set carry for next iteration + cin = cout; + + //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; + let flags = cout + 8 * plast[i]; + + // Store the required in the vector + let row = BinaryBasicTableSM::::calculate_table_row( + binary_basic_table_op, + a_bytes[i] as u64, + b_bytes[i] as u64, + previous_cin, + plast[i], + flags, + ); + multiplicity[row as usize] += 1; + } + } + LT_ABS_PN_OP => { + // Set opcode is min or max + row.op_is_min_max = F::zero(); + + // Set the binary basic table opcode + binary_basic_table_op = BinaryBasicTableOp::LtAbsPN; + + // Set use last carry + row.use_last_carry = F::one(); + + // Set has initial carry + row.has_initial_carry = F::one(); + + // Apply the logic to every byte + for i in 0..8 { + let _clt = cin & 0x1; + let _cop = (cin & 0x02) >> 1; + + let _a = a_bytes[i] as u64; + let _b = (b_bytes[i] as u64 ^ 0xFF) + _cop; + + // Calculate the output carry + let previous_cin = cin; + match _a.cmp(&(_b & 0xFF)) { + CmpOrdering::Less => { + cout = 1; + } + CmpOrdering::Equal => { + cout = _clt; + } + CmpOrdering::Greater => { + cout = 0; + } + } + + cout += 2 * (_b >> 8); + row.carry[i] = F::from_canonical_u64(cout); + + // Set carry for next iteration + cin = cout; + + //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; + let flags = cout + 8 * plast[i]; + + // Store the required in the vector + let row = BinaryBasicTableSM::::calculate_table_row( + binary_basic_table_op, + a_bytes[i] as u64, + b_bytes[i] as u64, + previous_cin, + plast[i], + flags, + ); + multiplicity[row as usize] += 1; + } + } + LTU_OP | LTUW_OP | LT_OP | LTW_OP => { + // Set opcode is min or max + row.op_is_min_max = F::zero(); + + // Set the binary basic table opcode + binary_basic_table_op = if (opcode == LTU_OP) || (opcode == LTUW_OP) { BinaryBasicTableOp::Ltu } else { BinaryBasicTableOp::Lt @@ -299,6 +514,9 @@ impl BinaryBasicSM { // Set use last carry to one row.use_last_carry = F::one(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // Apply the logic to every byte for i in 0..8 { // Calculate carry @@ -339,66 +557,65 @@ impl BinaryBasicSM { b_bytes[i] as u64, previous_cin, plast[i], - if i == 7 { c_bytes[0] as u64 } else { 0 }, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Leu | ZiskOp::LeuW | ZiskOp::Le | ZiskOp::LeW => { + GT_OP => { // Set opcode is min or max row.op_is_min_max = F::zero(); // Set the binary basic table opcode - binary_basic_table_op = if (op == ZiskOp::Leu) || (op == ZiskOp::LeuW) { - BinaryBasicTableOp::Leu - } else { - BinaryBasicTableOp::Le - }; + binary_basic_table_op = BinaryBasicTableOp::Gt; // Set use last carry to one row.use_last_carry = F::one(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // Apply the logic to every byte for i in 0..8 { // Calculate carry let previous_cin = cin; - cout = 0; - if a_bytes[i] <= b_bytes[i] { - cout = 1; + match a_bytes[i].cmp(&b_bytes[i]) { + CmpOrdering::Greater => { + cout = 1; + } + CmpOrdering::Less => { + cout = 0; + } + CmpOrdering::Equal => { + cout = cin; + } } - if (binary_basic_table_op == BinaryBasicTableOp::Le) && - (plast[i] == 1) && - (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) - { - cout = c; + + // The result is the sign of b + if (plast[i] == 1) && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { + cout = if b_bytes[i] & 0x80 != 0 { 1 } else { 0 }; } + row.carry[i] = F::from_canonical_u64(cout); + + // Set carry for next iteration cin = cout; - row.carry[i] = F::from_canonical_u64(cin); //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; - let flags = cin + 8 * plast[i]; + let flags = cout + 8 * plast[i]; // Store the required in the vector let row = BinaryBasicTableSM::::calculate_table_row( - if mode32 && (i >= 4) { - BinaryBasicTableOp::Ext32 - } else { - binary_basic_table_op - }, + binary_basic_table_op, a_bytes[i] as u64, b_bytes[i] as u64, previous_cin, plast[i], - if i == 7 { c_bytes[0] as u64 } else { 0 }, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Eq | ZiskOp::EqW => { + EQ_OP | EQW_OP => { // Set opcode is min or max row.op_is_min_max = F::zero(); @@ -408,6 +625,9 @@ impl BinaryBasicSM { // Set use last carry to one row.use_last_carry = F::one(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // Apply the logic to every byte for i in 0..8 { // Calculate carry @@ -437,61 +657,39 @@ impl BinaryBasicSM { b_bytes[i] as u64, previous_cin, plast[i], - if i == 7 { c_bytes[0] as u64 } else { 0 }, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Minu | ZiskOp::MinuW | ZiskOp::Min | ZiskOp::MinW => { + ADD_OP | ADDW_OP => { // Set opcode is min or max - row.op_is_min_max = F::one(); - - let result_is_a: u64 = - if (operation.a == operation.b) || (operation.b == c_filtered) { 0 } else { 1 }; + row.op_is_min_max = F::zero(); // Set the binary basic table opcode - binary_basic_table_op = if (op == ZiskOp::Minu) || (op == ZiskOp::MinuW) { - BinaryBasicTableOp::Minu - } else { - BinaryBasicTableOp::Min - }; + binary_basic_table_op = BinaryBasicTableOp::Add; // Set use last carry to zero row.use_last_carry = F::zero(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // Apply the logic to every byte for i in 0..8 { // Calculate carry let previous_cin = cin; - match a_bytes[i].cmp(&b_bytes[i]) { - CmpOrdering::Greater => { - cout = 0; - } - CmpOrdering::Less => { - cout = 1; - } - CmpOrdering::Equal => { - cout = cin; - } - } - - // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op == BinaryBasicTableOp::Min) && - (plast[i] == 1) && - (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) - { - cout = if (a_bytes[i] & 0x80) != 0 { 1 } else { 0 }; - } - if mode32 && (i >= 4) { - cout = 0; - } - cin = cout; + let result = cin + a_bytes[i] as u64 + b_bytes[i] as u64; + cout = result >> 8; + cin = if i == carry_byte { 0 } else { cout }; row.carry[i] = F::from_canonical_u64(cin); //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; - let flags = cout + 2 + 4 * result_is_a; + let flags = cin; + + // Set a and b bytes + let a_byte = if mode32 && (i >= 4) { c_bytes[3] } else { a_bytes[i] }; + let b_byte = if mode32 && (i >= 4) { 0 } else { b_bytes[i] }; // Store the required in the vector let row = BinaryBasicTableSM::::calculate_table_row( @@ -500,65 +698,95 @@ impl BinaryBasicSM { } else { binary_basic_table_op }, - if mode32 && (i >= 4) { c_bytes[3] as u64 } else { a_bytes[i] as u64 }, - b_bytes[i] as u64, + a_byte as u64, + b_byte as u64, previous_cin, plast[i], - c_bytes[i] as u64, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Maxu | ZiskOp::MaxuW | ZiskOp::Max | ZiskOp::MaxW => { + SUB_OP | SUBW_OP => { // Set opcode is min or max - row.op_is_min_max = F::one(); + row.op_is_min_max = F::zero(); - let result_is_a: u64 = - if (operation.a == operation.b) || (operation.b == c_filtered) { 0 } else { 1 }; + // Set the binary basic table opcode + binary_basic_table_op = BinaryBasicTableOp::Sub; + + // Set use last carry to zero + row.use_last_carry = F::zero(); + + // Set has initial carry + row.has_initial_carry = F::zero(); + + // Apply the logic to every byte + for i in 0..8 { + // Calculate carry + let previous_cin = cin; + cout = if a_bytes[i] as u64 >= (b_bytes[i] as u64 + cin) { 0 } else { 1 }; + cin = if i == carry_byte { 0 } else { cout }; + row.carry[i] = F::from_canonical_u64(cin); + + //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; + let flags = cin; + + // Set a and b bytes + let a_byte = if mode32 && (i >= 4) { c_bytes[3] } else { a_bytes[i] }; + let b_byte = if mode32 && (i >= 4) { 0 } else { b_bytes[i] }; + + // Store the required in the vector + let row = BinaryBasicTableSM::::calculate_table_row( + if mode32 && (i >= 4) { + BinaryBasicTableOp::Ext32 + } else { + binary_basic_table_op + }, + a_byte as u64, + b_byte as u64, + previous_cin, + plast[i], + flags, + ); + multiplicity[row as usize] += 1; + } + } + LEU_OP | LEUW_OP | LE_OP | LEW_OP => { + // Set opcode is min or max + row.op_is_min_max = F::zero(); // Set the binary basic table opcode - binary_basic_table_op = if (op == ZiskOp::Maxu) || (op == ZiskOp::MaxuW) { - BinaryBasicTableOp::Maxu + binary_basic_table_op = if (opcode == LEU_OP) || (opcode == LEUW_OP) { + BinaryBasicTableOp::Leu } else { - BinaryBasicTableOp::Max + BinaryBasicTableOp::Le }; - // Set use last carry to zero - row.use_last_carry = F::zero(); + // Set use last carry to one + row.use_last_carry = F::one(); + + // Set has initial carry + row.has_initial_carry = F::zero(); // Apply the logic to every byte for i in 0..8 { // Calculate carry let previous_cin = cin; - match a_bytes[i].cmp(&b_bytes[i]) { - CmpOrdering::Greater => { - cout = 1; - } - CmpOrdering::Less => { - cout = 0; - } - CmpOrdering::Equal => { - cout = cin; - } + cout = 0; + if a_bytes[i] <= b_bytes[i] { + cout = 1; } - - // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op == BinaryBasicTableOp::Max) && + if (binary_basic_table_op == BinaryBasicTableOp::Le) && (plast[i] == 1) && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { - cout = if (a_bytes[i] & 0x80) != 0 { 0 } else { 1 }; - } - if mode32 && (i >= 4) { - cout = 0; + cout = c; } cin = cout; row.carry[i] = F::from_canonical_u64(cin); //FLAGS[i] = cout + 2*op_is_min_max + 4*result_is_a + 8*USE_CARRY[i]*plast; - let flags = cout + 2 + 4 * result_is_a; + let flags = cin + 8 * plast[i]; // Store the required in the vector let row = BinaryBasicTableSM::::calculate_table_row( @@ -567,18 +795,16 @@ impl BinaryBasicSM { } else { binary_basic_table_op }, - if mode32 && (i >= 4) { c_bytes[3] as u64 } else { a_bytes[i] as u64 }, + a_bytes[i] as u64, b_bytes[i] as u64, previous_cin, plast[i], - c_bytes[i] as u64, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::And => { + AND_OP => { // Set opcode is min or max row.op_is_min_max = F::zero(); @@ -587,6 +813,9 @@ impl BinaryBasicSM { row.use_last_carry = F::zero(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // No carry for i in 0..8 { row.carry[i] = F::zero(); @@ -601,14 +830,12 @@ impl BinaryBasicSM { b_bytes[i] as u64, 0, plast[i], - c_bytes[i] as u64, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Or => { + OR_OP => { // Set opcode is min or max row.op_is_min_max = F::zero(); @@ -617,6 +844,9 @@ impl BinaryBasicSM { row.use_last_carry = F::zero(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // No carry for i in 0..8 { row.carry[i] = F::zero(); @@ -631,22 +861,24 @@ impl BinaryBasicSM { b_bytes[i] as u64, 0, plast[i], - c_bytes[i] as u64, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - ZiskOp::Xor => { + XOR_OP => { // Set opcode is min or max row.op_is_min_max = F::zero(); // Set the binary basic table opcode binary_basic_table_op = BinaryBasicTableOp::Xor; + // Set use last carry to zero row.use_last_carry = F::zero(); + // Set has initial carry + row.has_initial_carry = F::zero(); + // No carry for i in 0..8 { row.carry[i] = F::zero(); @@ -661,14 +893,39 @@ impl BinaryBasicSM { b_bytes[i] as u64, 0, plast[i], - c_bytes[i] as u64, flags, - i as u64, ); multiplicity[row as usize] += 1; } } - _ => panic!("BinaryBasicSM::process_slice() found invalid opcode={}", operation.opcode), + _ => panic!("BinaryBasicSM::process_slice() found invalid opcode={}", opcode), + } + + // Set cout + let cout32 = row.carry[HALF_BYTES - 1]; + let cout64 = row.carry[BYTES - 1]; + row.cout = mode64 * (cout64 - cout32) + cout32; + + // Set result_is_a + row.result_is_a = row.op_is_min_max * row.cout; + + // Set use_last_carry_mode32 and use_last_carry_mode64 + row.use_last_carry_mode32 = F::from_bool(mode32) * row.use_last_carry; + row.use_last_carry_mode64 = mode64 * row.use_last_carry; + + // Set micro opcode + row.m_op = F::from_canonical_u8(binary_basic_table_op as u8); + + // Set m_op_or_ext + let ext_32_op = F::from_canonical_u8(BinaryBasicTableOp::Ext32 as u8); + row.m_op_or_ext = mode64 * (row.m_op - ext_32_op) + ext_32_op; + + // Set free_in_a_or_c and free_in_b_or_zero + for i in 0..HALF_BYTES { + row.free_in_a_or_c[i] = mode64 * + (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) + + row.free_in_c[HALF_BYTES - 1]; + row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES]; } // Set cout @@ -746,12 +1003,11 @@ impl BinaryBasicSM { timer_stop_and_log_trace!(BINARY_TRACE); timer_start_trace!(BINARY_PADDING); + // Note: We can choose any operation that trivially satisfies the constraints on padding + // rows let padding_row = BinaryRow:: { - m_op: F::from_canonical_u8(0x20), - m_op_or_ext: F::from_canonical_u8(0x20), - multiplicity: F::zero(), - main_step: F::zero(), /* TODO: remove, since main_step is just for - * debugging */ + m_op: F::from_canonical_u8(AND_OP), + m_op_or_ext: F::from_canonical_u8(AND_OP), ..Default::default() }; @@ -769,8 +1025,6 @@ impl BinaryBasicSM { 0, last as u64, 0, - 0, - 0, ); multiplicity_table[row as usize] += multiplicity; } @@ -785,12 +1039,8 @@ impl BinaryBasicSM { drop(multiplicity_table); }); } -} - -impl WitnessComponent for BinaryBasicSM {} -impl Provable for BinaryBasicSM { - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, _scope: &Scope) { + pub fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool) { if let Ok(mut inputs) = self.inputs.lock() { inputs.extend_from_slice(operations); @@ -828,3 +1078,5 @@ impl Provable for BinaryBasicSM { } } } + +impl WitnessComponent for BinaryBasicSM {} diff --git a/state-machines/binary/src/binary_basic_table.rs b/state-machines/binary/src/binary_basic_table.rs index 374910f3..81ba43d8 100644 --- a/state-machines/binary/src/binary_basic_table.rs +++ b/state-machines/binary/src/binary_basic_table.rs @@ -14,21 +14,24 @@ use zisk_pil::{BinaryTableTrace, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] pub enum BinaryBasicTableOp { - Add = 0x02, - Sub = 0x03, - Ltu = 0x04, - Lt = 0x05, - Leu = 0x06, - Le = 0x07, - Eq = 0x08, - Minu = 0x09, - Min = 0x0a, - Maxu = 0x0b, - Max = 0x0c, - And = 0x20, - Or = 0x21, - Xor = 0x22, - Ext32 = 0x23, + Minu = 0x02, + Min = 0x03, + Maxu = 0x04, + Max = 0x05, + LtAbsNP = 0x06, + LtAbsPN = 0x07, + Ltu = 0x08, + Lt = 0x09, + Gt = 0x0a, + Eq = 0x0b, + Add = 0x0c, + Sub = 0x0d, + Leu = 0x0e, + Le = 0x0f, + And = 0x10, + Or = 0x11, + Xor = 0x12, + Ext32 = 0x13, } pub struct BinaryBasicTableSM { @@ -76,19 +79,20 @@ impl BinaryBasicTableSM { } } + // TODO: Add new ops? pub fn operations() -> Vec { vec![ - ZiskOp::Add.code(), - ZiskOp::Sub.code(), - ZiskOp::Ltu.code(), - ZiskOp::Lt.code(), - ZiskOp::Leu.code(), - ZiskOp::Le.code(), - ZiskOp::Eq.code(), ZiskOp::Minu.code(), ZiskOp::Min.code(), ZiskOp::Maxu.code(), ZiskOp::Max.code(), + ZiskOp::Ltu.code(), + ZiskOp::Lt.code(), + ZiskOp::Eq.code(), + ZiskOp::Add.code(), + ZiskOp::Sub.code(), + ZiskOp::Leu.code(), + ZiskOp::Le.code(), ZiskOp::And.code(), ZiskOp::Or.code(), ZiskOp::Xor.code(), @@ -112,10 +116,14 @@ impl BinaryBasicTableSM { b: u64, cin: u64, last: u64, - _c: u64, flags: u64, - _i: u64, ) -> u64 { + debug_assert!(a <= 0xFF); + debug_assert!(b <= 0xFF); + debug_assert!(cin <= 0x03); + debug_assert!(last <= 0x01); + debug_assert!(flags <= 0x0F); + // Calculate the different row offset contributors, according to the PIL if opcode == BinaryBasicTableOp::Ext32 { let offset_a: u64 = a; @@ -141,42 +149,46 @@ impl BinaryBasicTableSM { let offset_opcode: u64 = Self::offset_opcode(opcode); offset_a + offset_b + offset_last + offset_cin + offset_result_is_a + offset_opcode - //assert!(row < self.num_rows as u64); } } fn opcode_has_last(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Add | - BinaryBasicTableOp::Sub | - BinaryBasicTableOp::Ltu | - BinaryBasicTableOp::Lt | - BinaryBasicTableOp::Leu | - BinaryBasicTableOp::Le | - BinaryBasicTableOp::Eq | BinaryBasicTableOp::Minu | BinaryBasicTableOp::Min | BinaryBasicTableOp::Maxu | BinaryBasicTableOp::Max | + BinaryBasicTableOp::LtAbsNP | + BinaryBasicTableOp::LtAbsPN | + BinaryBasicTableOp::Ltu | + BinaryBasicTableOp::Lt | + BinaryBasicTableOp::Gt | + BinaryBasicTableOp::Eq | + BinaryBasicTableOp::Add | + BinaryBasicTableOp::Sub | + BinaryBasicTableOp::Leu | + BinaryBasicTableOp::Le | BinaryBasicTableOp::And | BinaryBasicTableOp::Or | BinaryBasicTableOp::Xor => true, BinaryBasicTableOp::Ext32 => false, - //_ => panic!("BinaryBasicTableSM::opcode_has_last() got invalid opcode={:?}", opcode), } } fn opcode_has_cin(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Add | - BinaryBasicTableOp::Sub | - BinaryBasicTableOp::Ltu | - BinaryBasicTableOp::Lt | - BinaryBasicTableOp::Eq | BinaryBasicTableOp::Minu | BinaryBasicTableOp::Min | BinaryBasicTableOp::Maxu | - BinaryBasicTableOp::Max => true, + BinaryBasicTableOp::Max | + BinaryBasicTableOp::LtAbsNP | + BinaryBasicTableOp::LtAbsPN | + BinaryBasicTableOp::Ltu | + BinaryBasicTableOp::Lt | + BinaryBasicTableOp::Gt | + BinaryBasicTableOp::Eq | + BinaryBasicTableOp::Add | + BinaryBasicTableOp::Sub => true, BinaryBasicTableOp::Leu | BinaryBasicTableOp::Le | @@ -184,29 +196,30 @@ impl BinaryBasicTableSM { BinaryBasicTableOp::Or | BinaryBasicTableOp::Xor | BinaryBasicTableOp::Ext32 => false, - //_ => panic!("BinaryBasicTableSM::opcode_has_cin() got invalid opcode={:?}", opcode), } } fn opcode_result_is_a(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Minu - | BinaryBasicTableOp::Min - | BinaryBasicTableOp::Maxu - | BinaryBasicTableOp::Max => true, - - BinaryBasicTableOp::Add - | BinaryBasicTableOp::Sub - | BinaryBasicTableOp::Ltu - | BinaryBasicTableOp::Lt - | BinaryBasicTableOp::Leu - | BinaryBasicTableOp::Le - | BinaryBasicTableOp::Eq - | BinaryBasicTableOp::And - | BinaryBasicTableOp::Or - | BinaryBasicTableOp::Xor - | BinaryBasicTableOp::Ext32 => false, - //_ => panic!("BinaryBasicTableSM::opcode_result_is_a() got invalid opcode={:?}", opcode), + BinaryBasicTableOp::Minu | + BinaryBasicTableOp::Min | + BinaryBasicTableOp::Maxu | + BinaryBasicTableOp::Max => true, + + BinaryBasicTableOp::LtAbsNP | + BinaryBasicTableOp::LtAbsPN | + BinaryBasicTableOp::Ltu | + BinaryBasicTableOp::Lt | + BinaryBasicTableOp::Gt | + BinaryBasicTableOp::Eq | + BinaryBasicTableOp::Add | + BinaryBasicTableOp::Sub | + BinaryBasicTableOp::Leu | + BinaryBasicTableOp::Le | + BinaryBasicTableOp::And | + BinaryBasicTableOp::Or | + BinaryBasicTableOp::Xor | + BinaryBasicTableOp::Ext32 => false, } } @@ -216,18 +229,20 @@ impl BinaryBasicTableSM { BinaryBasicTableOp::Min => P2_19, BinaryBasicTableOp::Maxu => 2 * P2_19, BinaryBasicTableOp::Max => 3 * P2_19, - BinaryBasicTableOp::Ltu => 4 * P2_19, - BinaryBasicTableOp::Lt => 4 * P2_19 + P2_18, - BinaryBasicTableOp::Eq => 4 * P2_19 + 2 * P2_18, - BinaryBasicTableOp::Add => 4 * P2_19 + 3 * P2_18, - BinaryBasicTableOp::Sub => 4 * P2_19 + 4 * P2_18, - BinaryBasicTableOp::Leu => 4 * P2_19 + 5 * P2_18, - BinaryBasicTableOp::Le => 4 * P2_19 + 5 * P2_18 + P2_17, - BinaryBasicTableOp::And => 4 * P2_19 + 5 * P2_18 + 2 * P2_17, - BinaryBasicTableOp::Or => 4 * P2_19 + 5 * P2_18 + 3 * P2_17, - BinaryBasicTableOp::Xor => 4 * P2_19 + 5 * P2_18 + 4 * P2_17, - BinaryBasicTableOp::Ext32 => 4 * P2_19 + 5 * P2_18 + 5 * P2_17, - //_ => panic!("BinaryBasicTableSM::offset_opcode() got invalid opcode={:?}", opcode), + BinaryBasicTableOp::LtAbsNP => 4 * P2_19, + BinaryBasicTableOp::LtAbsPN => 5 * P2_19, + BinaryBasicTableOp::Ltu => 6 * P2_19, + BinaryBasicTableOp::Lt => 6 * P2_19 + P2_18, + BinaryBasicTableOp::Gt => 6 * P2_19 + 2 * P2_18, + BinaryBasicTableOp::Eq => 6 * P2_19 + 3 * P2_18, + BinaryBasicTableOp::Add => 6 * P2_19 + 4 * P2_18, + BinaryBasicTableOp::Sub => 6 * P2_19 + 5 * P2_18, + BinaryBasicTableOp::Leu => 6 * P2_19 + 6 * P2_18, + BinaryBasicTableOp::Le => 6 * P2_19 + 6 * P2_18 + P2_17, + BinaryBasicTableOp::And => 6 * P2_19 + 6 * P2_18 + 2 * P2_17, + BinaryBasicTableOp::Or => 6 * P2_19 + 6 * P2_18 + 3 * P2_17, + BinaryBasicTableOp::Xor => 6 * P2_19 + 6 * P2_18 + 4 * P2_17, + BinaryBasicTableOp::Ext32 => 6 * P2_19 + 6 * P2_18 + 5 * P2_17, } } diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index 5077107c..129b8202 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -14,8 +14,6 @@ use pil_std_lib::Std; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::*; @@ -32,6 +30,8 @@ const SIGN_BYTE: u64 = 0x80; const LS_5_BITS: u64 = 0x1F; const LS_6_BITS: u64 = 0x3F; +const SE_W_OP: u8 = 0x39; + pub struct BinaryExtensionSM { // Witness computation manager wcm: Arc>, @@ -407,8 +407,10 @@ impl BinaryExtensionSM { timer_stop_and_log_debug!(BINARY_EXTENSION_TRACE); timer_start_debug!(BINARY_EXTENSION_PADDING); + // Note: We can choose any operation that trivially satisfies the constraints on padding + // rows let padding_row = - BinaryExtensionRow:: { op: F::from_canonical_u64(0x25), ..Default::default() }; + BinaryExtensionRow:: { op: F::from_canonical_u8(SE_W_OP), ..Default::default() }; for i in operations.len()..air.num_rows() { trace_buffer[i] = padding_row; @@ -448,12 +450,8 @@ impl BinaryExtensionSM { drop(range_check); }); } -} - -impl WitnessComponent for BinaryExtensionSM {} -impl Provable for BinaryExtensionSM { - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, _scope: &Scope) { + pub fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool) { if let Ok(mut inputs) = self.inputs.lock() { inputs.extend_from_slice(operations); @@ -494,3 +492,5 @@ impl Provable for BinaryExtensio } } } + +impl WitnessComponent for BinaryExtensionSM {} diff --git a/state-machines/binary/src/binary_extension_table.rs b/state-machines/binary/src/binary_extension_table.rs index 6916b78a..1b7c2c38 100644 --- a/state-machines/binary/src/binary_extension_table.rs +++ b/state-machines/binary/src/binary_extension_table.rs @@ -14,15 +14,15 @@ use zisk_pil::{BinaryExtensionTableTrace, BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_A #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] pub enum BinaryExtensionTableOp { - Sll = 0x0d, - Srl = 0x0e, - Sra = 0x0f, - SllW = 0x1d, - SrlW = 0x1e, - SraW = 0x1f, - SignExtendB = 0x23, - SignExtendH = 0x24, - SignExtendW = 0x25, + Sll = 0x31, + Srl = 0x32, + Sra = 0x33, + SllW = 0x34, + SrlW = 0x35, + SraW = 0x36, + SignExtendB = 0x37, + SignExtendH = 0x38, + SignExtendW = 0x39, } pub struct BinaryExtensionTableSM { @@ -71,11 +71,13 @@ impl BinaryExtensionTableSM { } pub fn operations() -> Vec { - // TODO! Review this codes vec![ ZiskOp::Sll.code(), ZiskOp::Srl.code(), ZiskOp::Sra.code(), + ZiskOp::SllW.code(), + ZiskOp::SrlW.code(), + ZiskOp::SraW.code(), ZiskOp::SignExtendB.code(), ZiskOp::SignExtendH.code(), ZiskOp::SignExtendW.code(), @@ -92,17 +94,17 @@ impl BinaryExtensionTableSM { //lookup_proves(BINARY_EXTENSION_TABLE_ID, [OP, OFFSET, A, B, C0, C1], multiplicity); pub fn calculate_table_row(opcode: BinaryExtensionTableOp, offset: u64, a: u64, b: u64) -> u64 { + debug_assert!(offset <= 0x07); + debug_assert!(a <= 0xFF); + debug_assert!(b <= 0xFF); + // Calculate the different row offset contributors, according to the PIL - assert!(a <= 0xFF); let offset_a: u64 = a; - assert!(offset < 0x08); let offset_offset: u64 = offset * P2_8; - assert!(b <= 0xFF); let offset_b: u64 = b * P2_11; let offset_opcode: u64 = Self::offset_opcode(opcode); offset_a + offset_offset + offset_b + offset_opcode - //assert!(row < self.num_rows as u64); } fn offset_opcode(opcode: BinaryExtensionTableOp) -> u64 { @@ -116,7 +118,6 @@ impl BinaryExtensionTableSM { BinaryExtensionTableOp::SignExtendB => 6 * P2_19, BinaryExtensionTableOp::SignExtendH => 6 * P2_19 + P2_11, BinaryExtensionTableOp::SignExtendW => 6 * P2_19 + 2 * P2_11, - //_ => panic!("BinaryExtensionTableSM::offset_opcode() got invalid opcode={:?}", opcode), } } diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index b787f3c2..740a6e63 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -275,6 +275,6 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope lookup_assumes(ROM_BUS_ID, [pc, a_offset_imm0, a_imm1, b_offset_imm0, b_imm1, ind_width, op, store_offset, jmp_offset1, jmp_offset2, rom_flags], sel: 1 - SEGMENT_L1); - direct_global_update(MAIN_CONTINUATION_ID, cols: [0, 0, 4096, 0, 0], bus_type: PIOP_BUS_SUM, proves: 1); - direct_global_update(MAIN_CONTINUATION_ID, cols: [0, 1, 0x10000000, 0, 0], bus_type: PIOP_BUS_SUM, proves: 0); + direct_global_update_proves(MAIN_CONTINUATION_ID, cols: [0, 0, 4096, 0, 0, 0, 0, 0], bus_type: PIOP_BUS_SUM); + direct_global_update_assumes(MAIN_CONTINUATION_ID, cols: [0, 1, 0x1000_0000, 0, 0, 0, 0, 0], bus_type: PIOP_BUS_SUM); } \ No newline at end of file diff --git a/tools/verify_all.sh b/tools/verify_all.sh index 8e46e917..09abe5b5 100755 --- a/tools/verify_all.sh +++ b/tools/verify_all.sh @@ -4,7 +4,7 @@ echo "Verify all ELF files found in a directory" # Check that at least one argument has been passed if [ "$#" -lt 1 ]; then - echo "Usage: $0 [-l/--list -b/--begin -e/--end ]" + echo "Usage: $0 [-l/--list -b/--begin -e/--end -d/--debug]" exit 1 fi @@ -21,11 +21,13 @@ echo "Verifying ELF files found in directory ${DIR}" LIST=0 BEGIN=0 END=0 +DEBUG=0 while [[ "$#" -gt 0 ]]; do case $1 in -l|--list) LIST=1 ;; -b|--begin) BEGIN=$2; shift; ;; -e|--end) END=$2; shift; ;; + -d|--debug) DEBUG=1 ;; *) echo "Unknown parameter passed: $1"; exit 1 ;; esac shift @@ -46,7 +48,7 @@ if [ $BEGIN -ne 0 ]; then echo "Beginning at file ${BEGIN}"; fi if [ $END -ne 0 ]; then - echo "Beginning at file ${END}"; + echo "Ending at file ${END}"; fi # If just listing, exit @@ -58,7 +60,7 @@ fi # Record the number of files MAX_COUNTER=${COUNTER} -# Create and empty input file +# Create an empty input file INPUT_FILE="/tmp/empty_input.bin" touch $INPUT_FILE @@ -82,6 +84,13 @@ do # Varify the contraints for this file echo "" echo "Verifying file ${COUNTER} of ${MAX_COUNTER}: ${ELF_FILE}" - (cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli verify-constraints --witness-lib ../zisk/target/release/libzisk_witness.so --rom $ELF_FILE -i $INPUT_FILE --proving-key ../zisk/build/provingKey) + + if [ $DEBUG -eq 1 ]; then + # Run with debug flag + (cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli verify-constraints --witness-lib ../zisk/target/release/libzisk_witness.so --rom $ELF_FILE -i $INPUT_FILE --proving-key ../zisk/build/provingKey -d) + else + # Run without debug flag + (cargo build --release && cd ../pil2-proofman; cargo run --release --bin proofman-cli verify-constraints --witness-lib ../zisk/target/release/libzisk_witness.so --rom $ELF_FILE -i $INPUT_FILE --proving-key ../zisk/build/provingKey) + fi done diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 9e3834a5..8d89c2dc 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -51,7 +51,7 @@ impl ZiskExecutor { let rom_sm = RomSM::new(wcm.clone()); let mem_sm = MemSM::new(wcm.clone()); let binary_sm = BinarySM::new(wcm.clone(), std.clone()); - let arith_sm = ArithSM::new(wcm.clone()); + let arith_sm = ArithSM::new(wcm.clone(), binary_sm.clone()); // If rom_path has an .elf extension it must be converted to a ZisK ROM let zisk_rom = if rom_path.extension().unwrap() == "elf" { From a7b1300016c591f47057ac14ad733177cfe46d8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip=20Ardevol?= Date: Wed, 11 Dec 2024 14:13:12 +0100 Subject: [PATCH 3/6] Cleaning the binary component (#182) * Executor and pil done * Update binary_basic.rs * name change * traces update --- pil/src/pil_helpers/traces.rs | 4 +-- state-machines/binary/pil/binary.pil | 4 +-- .../binary/pil/binary_extension.pil | 4 +-- state-machines/binary/src/binary_basic.rs | 29 +------------------ state-machines/binary/src/binary_extension.rs | 2 +- 5 files changed, 8 insertions(+), 35 deletions(-) diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 4b09b133..47541a74 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -24,7 +24,7 @@ trace!(ArithRangeTableRow, ArithRangeTableTrace { }); trace!(BinaryRow, BinaryTrace { - m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, has_initial_carry: F, cout: F, result_is_a: F, use_last_carry_mode32: F, use_last_carry_mode64: F, m_op_or_ext: F, free_in_a_or_c: [F; 4], free_in_b_or_zero: [F; 4], multiplicity: F, main_step: F, + m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, has_initial_carry: F, cout: F, result_is_a: F, use_last_carry_mode32: F, use_last_carry_mode64: F, m_op_or_ext: F, free_in_a_or_c: [F; 4], free_in_b_or_zero: [F; 4], multiplicity: F, debug_main_step: F, }); trace!(BinaryTableRow, BinaryTableTrace { @@ -32,7 +32,7 @@ trace!(BinaryTableRow, BinaryTableTrace { }); trace!(BinaryExtensionRow, BinaryExtensionTrace { - op: F, in1: [F; 8], in2_low: F, out: [[F; 2]; 8], op_is_shift: F, in2: [F; 2], main_step: F, multiplicity: F, + op: F, in1: [F; 8], in2_low: F, out: [[F; 2]; 8], op_is_shift: F, in2: [F; 2], debug_main_step: F, multiplicity: F, }); trace!(BinaryExtensionTableRow, BinaryExtensionTableTrace { diff --git a/state-machines/binary/pil/binary.pil b/state-machines/binary/pil/binary.pil index 8d8d5634..29e2c731 100644 --- a/state-machines/binary/pil/binary.pil +++ b/state-machines/binary/pil/binary.pil @@ -171,6 +171,6 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id) { expr op = m_op + 0x20 * mode32; col witness multiplicity; - col witness main_step; - lookup_proves(OPERATION_BUS_ID, [main_step, op, ...a, ...b, ...c, cout - result_is_a], multiplicity); + col witness debug_main_step; + lookup_proves(OPERATION_BUS_ID, [debug_main_step, op, ...a, ...b, ...c, cout - result_is_a], multiplicity); } \ No newline at end of file diff --git a/state-machines/binary/pil/binary_extension.pil b/state-machines/binary/pil/binary_extension.pil index c5d9fc4a..13681fc1 100644 --- a/state-machines/binary/pil/binary_extension.pil +++ b/state-machines/binary/pil/binary_extension.pil @@ -86,12 +86,12 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id) { expr in1_low = in1[0] + in1[1]*2**8 + in1[2]*2**16 + in1[3]*2**24; expr in1_high = in1[4] + in1[5]*2**8 + in1[6]*2**16 + in1[7]*2**24; - col witness main_step; + col witness debug_main_step; col witness multiplicity; lookup_proves( operation_bus_id, [ - main_step, + debug_main_step, op, op_is_shift * (in1_low - in2[0]) + in2[0], op_is_shift * (in1_high - in2[1]) + in2[1], diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index dccedaf5..aefe09f6 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -239,7 +239,7 @@ impl BinaryBasicSM { } // Set main SM step - row.main_step = F::from_canonical_u64(operation.step); + row.debug_main_step = F::from_canonical_u64(operation.step); // Set use last carry and carry[], based on operation let mut cout: u64; @@ -928,33 +928,6 @@ impl BinaryBasicSM { row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES]; } - // Set cout - let cout32 = row.carry[HALF_BYTES - 1]; - let cout64 = row.carry[BYTES - 1]; - row.cout = mode64 * (cout64 - cout32) + cout32; - - // Set result_is_a - row.result_is_a = row.op_is_min_max * row.cout; - - // Set use_last_carry_mode32 and use_last_carry_mode64 - row.use_last_carry_mode32 = F::from_bool(mode32) * row.use_last_carry; - row.use_last_carry_mode64 = mode64 * row.use_last_carry; - - // Set micro opcode - row.m_op = F::from_canonical_u8(binary_basic_table_op as u8); - - // Set m_op_or_ext - let ext_32_op = F::from_canonical_u8(BinaryBasicTableOp::Ext32 as u8); - row.m_op_or_ext = mode64 * (row.m_op - ext_32_op) + ext_32_op; - - // Set free_in_a_or_c and free_in_b_or_zero - for i in 0..HALF_BYTES { - row.free_in_a_or_c[i] = mode64 * - (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) + - row.free_in_c[HALF_BYTES - 1]; - row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES]; - } - if row.use_last_carry == F::one() { // Set first and last elements row.free_in_c[7] = row.free_in_c[0]; diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index 129b8202..ded8972f 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -191,7 +191,7 @@ impl BinaryExtensionSM { row.in2[1] = F::from_canonical_u64(in2_1); // Set main SM step - row.main_step = F::from_canonical_u64(operation.step); + row.debug_main_step = F::from_canonical_u64(operation.step); // Calculate the trace output let mut t_out: [[u64; 2]; 8] = [[0; 2]; 8]; From e5f11b23edeb3c9b913109be0609054040392605 Mon Sep 17 00:00:00 2001 From: fractasy <89866610+fractasy@users.noreply.github.com> Date: Wed, 11 Dec 2024 15:26:25 +0100 Subject: [PATCH 4/6] Development of mem SM: MemProxy, MemAlign, MemAlignRom, Mem, RomData, InputData. (#153) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First commit on memory_checkpoint3 * Create input_data_sm.rs * mem_proxy to little endian * update to proofman 0.0.10 * Fix errors * mem align working * Write RO data sections using code instructions * Fix riscof test by not writing RO sections with addr=0 and fixing vector index * Ignore ELF sections with addr=0 * Mem align fully working * Cleaning the mem align * Removing hashmaps and cleaning up stuff a little bit * mem - mem_align integration * fix invalid assert on mem_align * Remove store_c_slice to fix parallel execution without memory * Mem align and mem fixes (#175) * Removing unnecessary code * update mem executors with correct call to register_predecessor * Memory required extended data to build Memory SM proof * Feature/custom commits (#166) * Custom cols rom (#159) Custom cols working --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> * Cached custom commits * Updating proofman to 0.0.12 * Global constraints verifying again * Optimizing the binary component (#167) * Optimizing the binary * Updating the executor * Updating to 0.0.13 * Not creating unnecessary instances of arith tables * Pil2-proofman 0.0.14 --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Co-authored-by: Héctor Masip Ardevol * remove old input_data.pil * Zisk working with last proofman version * Updating book and Cargo.toml to point to 0.0.16 proofman * fix minor bug after update develop changes * fix bug on additional mem calculation * update direct_global continuations * update mem pil comments * fix last bugs on memory, added documentation --------- Co-authored-by: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Co-authored-by: Héctor Masip Co-authored-by: zkronos73 Co-authored-by: Roger Taulé Buxadera <55488871+RogerTaule@users.noreply.github.com> Co-authored-by: RogerTaule --- .gitignore | 6 +- Cargo.lock | 4 + core/src/elf2rom.rs | 59 +- core/src/mem.rs | 482 +++++++- core/src/riscv2zisk_context.rs | 50 +- core/src/zisk_ops.rs | 2 +- core/src/zisk_required_operation.rs | 62 +- emulator/src/emu.rs | 309 ++++- emulator/src/emu_context.rs | 24 +- emulator/src/emu_trace.rs | 16 +- emulator/src/emulator.rs | 20 +- pil/src/lib.rs | 3 - pil/src/pil_helpers/pilout.rs | 39 +- pil/src/pil_helpers/traces.rs | 30 +- pil/zisk.pil | 16 +- state-machines/arith/pil/arith_table.pil | 6 +- state-machines/arith/src/arith_full.rs | 3 +- state-machines/arith/src/arith_table.rs | 2 - .../arith/src/arith_table_helpers.rs | 4 +- .../binary/pil/binary_extension_table.pil | 22 +- state-machines/binary/src/binary_basic.rs | 2 +- state-machines/binary/src/binary_extension.rs | 5 +- state-machines/main/pil/main.pil | 42 +- state-machines/main/src/main_sm.rs | 47 +- state-machines/mem/Cargo.toml | 11 +- state-machines/mem/pil/mem.pil | 265 +++-- state-machines/mem/pil/mem_align.pil | 188 +++ state-machines/mem/pil/mem_align_rom.pil | 323 ++++++ state-machines/mem/src/input_data_sm.rs | 377 ++++++ state-machines/mem/src/lib.rs | 30 +- state-machines/mem/src/mem.rs | 101 -- state-machines/mem/src/mem_align_rom_sm.rs | 214 ++++ state-machines/mem/src/mem_align_sm.rs | 1015 +++++++++++++++++ state-machines/mem/src/mem_aligned.rs | 112 -- state-machines/mem/src/mem_constants.rs | 21 +- state-machines/mem/src/mem_helpers.rs | 106 +- state-machines/mem/src/mem_module.rs | 31 + state-machines/mem/src/mem_proxy.rs | 79 ++ state-machines/mem/src/mem_proxy_engine.rs | 628 ++++++++++ state-machines/mem/src/mem_sm.rs | 383 +++++++ state-machines/mem/src/mem_traces.rs | 5 - state-machines/mem/src/mem_unaligned.rs | 114 -- state-machines/mem/src/mem_unmapped.rs | 35 + state-machines/mem/src/rom_data.rs | 339 ++++++ witness-computation/src/executor.rs | 113 +- 45 files changed, 5082 insertions(+), 663 deletions(-) create mode 100644 state-machines/mem/pil/mem_align_rom.pil create mode 100644 state-machines/mem/src/input_data_sm.rs delete mode 100644 state-machines/mem/src/mem.rs create mode 100644 state-machines/mem/src/mem_align_rom_sm.rs create mode 100644 state-machines/mem/src/mem_align_sm.rs delete mode 100644 state-machines/mem/src/mem_aligned.rs create mode 100644 state-machines/mem/src/mem_module.rs create mode 100644 state-machines/mem/src/mem_proxy.rs create mode 100644 state-machines/mem/src/mem_proxy_engine.rs create mode 100644 state-machines/mem/src/mem_sm.rs delete mode 100644 state-machines/mem/src/mem_traces.rs delete mode 100644 state-machines/mem/src/mem_unaligned.rs create mode 100644 state-machines/mem/src/mem_unmapped.rs create mode 100644 state-machines/mem/src/rom_data.rs diff --git a/.gitignore b/.gitignore index 63d76060..70c16c23 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,8 @@ /target /*.tar.gz /riscof -/build -/proofs +/build* +/proofs* *.pilout /tmp -*.log \ No newline at end of file +*.log diff --git a/Cargo.lock b/Cargo.lock index e59ff2b5..cd9341db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2238,10 +2238,14 @@ name = "sm-mem" version = "0.1.0" dependencies = [ "log", + "num-bigint", + "num-traits", "p3-field", + "pil-std-lib", "proofman", "proofman-common", "proofman-macros", + "proofman-util", "rayon", "sm-common", "zisk-core", diff --git a/core/src/elf2rom.rs b/core/src/elf2rom.rs index 668c8aee..0209aa27 100644 --- a/core/src/elf2rom.rs +++ b/core/src/elf2rom.rs @@ -28,7 +28,15 @@ pub fn elf2rom(elf_file: String) -> Result> { for section_header in section_headers { // Consider only the section headers that contain program data if section_header.sh_type == SHT_PROGBITS { - // Get the program section data as a vector of bytes + // Get the section header address + let addr = section_header.sh_addr; + + // Ignore sections with address = 0, as per ELF spec + if addr == 0 { + continue; + } + + // Get the section data let (data_u8, _) = elf_bytes.section_data(§ion_header)?; let mut data = data_u8.to_vec(); @@ -37,31 +45,58 @@ pub fn elf2rom(elf_file: String) -> Result> { data.pop(); } - // Get the section data address - let addr = section_header.sh_addr; - - // If the data contains instructions, parse them as RISC-V instructions and add them - // to the ROM instructions, at the specified program address + // If this is a code section, add it to program if (section_header.sh_flags & SHF_EXECINSTR as u64) != 0 { add_zisk_code(&mut rom, addr, &data); } + // Add init data as a read/write memory section, initialized by code // If the data is a writable memory section, add it to the ROM memory using Zisk // copy instructions if (section_header.sh_flags & SHF_WRITE as u64) != 0 && addr >= RAM_ADDR && addr + data.len() as u64 <= RAM_ADDR + RAM_SIZE { - add_zisk_init_data(&mut rom, addr, &data); - // Otherwise, add it to the ROM as RO data - } else { - rom.ro_data.push(RoData::new(addr, data.len(), data)); + //println! {"elf2rom() new RW from={:x} length={:x}={}", addr, data.len(), + //data.len()}; + add_zisk_init_data(&mut rom, addr, &data, true); + } + // Add read-only data memory section + else { + // Search for an existing RO section previous to this one + let mut found = false; + for rd in rom.ro_data.iter_mut() { + // Section data should be previous to this one + if (rd.from + rd.length as u64) == addr { + rd.length += data.len(); + rd.data.extend(data.clone()); + found = true; + //println! {"elf2rom() adding RO from={:x} length={:x}={}", rd.from, + // rd.length, rd.length}; + break; + } + } + + // If not found, create a new RO section + if !found { + //println! {"elf2rom() new RO from={:x} length={:x}={}", addr, data.len(), + // data.len()}; + rom.ro_data.push(RoData::new(addr, data.len(), data)); + } } } } } - // Add the program setup, system call and program wrapup instructions + // Add RO data initialization code insctructions + let ro_data_len = rom.ro_data.len(); + for i in 0..ro_data_len { + let addr = rom.ro_data[i].from; + let mut data = Vec::new(); + data.extend(rom.ro_data[i].data.as_slice()); + add_zisk_init_data(&mut rom, addr, &data, true); + } + add_entry_exit_jmp(&mut rom, elf_bytes.ehdr.e_entry); // Preprocess the ROM (experimental) @@ -128,6 +163,8 @@ pub fn elf2rom(elf_file: String) -> Result> { } } + //println! {"elf2rom() got rom.insts.len={}", rom.insts.len()}; + Ok(rom) } diff --git a/core/src/mem.rs b/core/src/mem.rs index f5febf93..10c1d719 100644 --- a/core/src/mem.rs +++ b/core/src/mem.rs @@ -5,49 +5,48 @@ //! * The Zisk processor memory stores data in little-endian format. //! * The addressable memory space is divided into several regions described in the following map: //! -//! `|--------------- ROM_ENTRY: first BIOS instruction ( 0x1000)` -//! `|` -//! `| Performs memory initialization, calls program at ROM_ADDR,` -//! `| and after returning it performs memory finalization.` -//! `| Contains ecall/system call management code.` -//! `|` -//! `|--------------- ROM_EXIT: last BIOS instruction (0x10000000)` -//! ` ...` -//! `|--------------- ROM_ADDR: first program instruction (0x80000000)` -//! `|` -//! `| Contains program instructions.` -//! `| Calls ecalls/system calls when required.` -//! `|` -//! `|--------------- INPUT_ADDR (0x90000000)` -//! `|` -//! `| Contains program input data.` -//! `|` -//! `|--------------- SYS_ADDR (= RAM_ADDR = REG_FIRST) (0xa0000000)` -//! `|` -//! `| Contains system address.` -//! `| The first 256 bytes contain 32 8-byte registers` -//! `| The address UART_ADDR is used as a standard output` -//! `|` -//! `|--------------- OUTPUT_ADDR (0xa0010000)` -//! `|` -//! `| Contains output data, which is written during` -//! `| program execution and read during memory finalization` -//! `|` -//! `|--------------- AVAILABLE_MEM_ADDR (0xa0020000)` -//! `|` -//! `| Contains program memory, available for normal R/W` -//! `| use during program execution.` -//! `|` -//! `|--------------- (0xb0000000)` -//! ` ...` +//! `|--------------- ROM_ENTRY: first BIOS instruction ( 0x1000)` +//! `|` +//! `| Performs memory initialization, calls program at ROM_ADDR,` +//! `| and after returning it performs memory finalization.` +//! `| Contains ecall/system call management code.` +//! `|` +//! `|--------------- ROM_EXIT: last BIOS instruction (0x10000000)` +//! ` ...` +//! `|--------------- ROM_ADDR: first program instruction (0x80000000)` +//! `|` +//! `| Contains program instructions.` +//! `| Calls ecalls/system calls when required.` +//! `|` +//! `|--------------- INPUT_ADDR (0x90000000)` +//! `|` +//! `| Contains program input data.` +//! `|` +//! `|--------------- SYS_ADDR (= RAM_ADDR = REG_FIRST) (0xa0000000)` +//! `|` +//! `| Contains system address.` +//! `| The first 256 bytes contain 32 8-byte registers` +//! `| The address UART_ADDR is used as a standard output` +//! `|` +//! `|--------------- OUTPUT_ADDR (0xa0010000)` +//! `|` +//! `| Contains output data, which is written during` +//! `| program execution and read during memory finalization` +//! `|` +//! `|--------------- AVAILABLE_MEM_ADDR (0xa0020000)` +//! `|` +//! `| Contains program memory, available for normal R/W` +//! `| use during program execution.` +//! `|` +//! `|--------------- (0xb0000000)` +//! ` ...` //! //! ## ROM_ENTRY / ROM_ADDR / ROM_EXIT //! * The program will start executing at the first BIOS address `ROM_ENTRY`. //! * The first instructions do the basic program setup, including writing the input data into //! memory, configuring the ecall (system call) program address, and configuring the program //! completion return address. -//! * After the program setup, the program counter jumps to `ROM_ADDR`, executing the actual -//! program. +//! * After the program set1, the program counter jumps to `ROM_ADDR`, executing the actual program. //! * During the execution, the program can make system calls that will jump to the configured ecall //! program address, and return once the task has completed. The precompiled are implemented via //! ecall. @@ -79,14 +78,16 @@ //! * The third RW memory region going from `AVAILABLE_MEM_ADDR` onwards can be used during the //! program execution a general purpose memory. +use std::fmt; + /// Fist input data memory address pub const INPUT_ADDR: u64 = 0x90000000; /// Maximum size of the input data -pub const MAX_INPUT_SIZE: u64 = 0x10000000; // 256M, +pub const MAX_INPUT_SIZE: u64 = 0x08000000; // 128M, /// First globa RW memory address pub const RAM_ADDR: u64 = 0xa0000000; /// Size of the global RW memory -pub const RAM_SIZE: u64 = 0x10000000; // 256M +pub const RAM_SIZE: u64 = 0x08000000; // 128M /// First system RW memory address pub const SYS_ADDR: u64 = RAM_ADDR; /// Size of the system RW memory @@ -106,7 +107,7 @@ pub const ROM_EXIT: u64 = 0x10000000; /// First program ROM instruction address, i.e. first RISC-V transpiled instruction pub const ROM_ADDR: u64 = 0x80000000; /// Maximum program ROM instruction address -pub const ROM_ADDR_MAX: u64 = INPUT_ADDR - 1; +pub const ROM_ADDR_MAX: u64 = (ROM_ADDR + 0x08000000) - 1; // 128M /// Zisk architecture ID pub const ARCH_ID_ZISK: u64 = 0xFFFEEEE; /// UART memory address; single bytes written here will be copied to the standard output @@ -114,13 +115,46 @@ pub const UART_ADDR: u64 = SYS_ADDR + 512; /// Memory section data, including a buffer (a vector of bytes) and start and end program /// memory addresses. -#[derive(Default)] pub struct MemSection { pub start: u64, pub end: u64, + pub real_end: u64, pub buffer: Vec, } +/// Default constructor for MemSection structure +impl Default for MemSection { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for MemSection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(&self.to_text()) + } +} + +/// Memory section structure implementation +impl MemSection { + /// Memory section constructor + pub fn new() -> MemSection { + MemSection { start: 0, end: 0, real_end: 0, buffer: Vec::new() } + } + pub fn to_text(&self) -> String { + format!( + "start={:x} real_end={:x} end={:x} diff={:x}={} buffer.len={:x}={}", + self.start, + self.real_end, + self.end, + self.end - self.start, + self.end - self.start, + self.buffer.len(), + self.buffer.len() + ) + } +} + /// Memory structure, containing several read sections and one single write section #[derive(Default)] pub struct Mem { @@ -129,6 +163,12 @@ pub struct Mem { } impl Mem { + /// Memory structue constructor + pub fn new() -> Mem { + //println!("Mem::new()"); + Mem { read_sections: Vec::new(), write_section: MemSection::new() } + } + /// Adds a read section to the memory structure pub fn add_read_section(&mut self, start: u64, buffer: &[u8]) { // Check that the start address is alligned to 8 bytes @@ -142,31 +182,60 @@ impl Mem { // Calculate the end address let end = start + buffer.len() as u64; - // Create a mem section with this data - let mut mem_section = MemSection { start, end, buffer: buffer.to_owned() }; + // If there exists a read section next to this one, reuse it + for existing_section in self.read_sections.iter_mut() { + if existing_section.real_end == start { + // Sanity check + assert!(existing_section.real_end <= existing_section.end); + assert!((existing_section.end - existing_section.real_end) < 8); + + // Pop tail zeros until end matches real_end + while existing_section.real_end > existing_section.end { + existing_section.buffer.pop(); + existing_section.end -= 1; + } + + // Append buffer + existing_section.buffer.extend(buffer); + existing_section.real_end += buffer.len() as u64; + existing_section.end = existing_section.real_end; + + // Append zeros until end is multiple of 8, so that we can read non-alligned reads + while (existing_section.end & 0x07) != 0 { + existing_section.buffer.push(0); + existing_section.end += 1; + } + + /*println!( + "Mem::add_read_section() start={:x} len={} existing section={}", + start, + buffer.len(), + existing_section.to_text() + );*/ + + return; + } + } + + // Create a new memory section + let mut new_section = MemSection { start, end, real_end: end, buffer: buffer.to_owned() }; - // Add zero-value bytes until the end address is alligned to 8 bytes - while (mem_section.end) % 8 != 0 { - mem_section.buffer.push(0); - mem_section.end += 1; + // Append zeros until end is multiple of 8, so that we can read non-alligned reads + while (new_section.end & 0x07) != 0 { + new_section.buffer.push(0); + new_section.end += 1; } - // Push the new read section to the read sections list - self.read_sections.push(mem_section); - - /*println!( - "Mem::add_read_section() start={:x}={} len={} end={:x}={}", - start, - start, - buffer.len(), - end, - end - );*/ + //println!("Mem::add_read_section() new section={}", new_section.to_text()); + + // Add the new section to the read sections + self.read_sections.push(new_section); } /// Adds a write section to the memory structure, which cannot be written twice pub fn add_write_section(&mut self, start: u64, size: u64) { - //println!("Mem::add_write_section() start={} size={}", start, size); + //println!("Mem::add_write_section() start={:x}={} size={:x}={}", start, start, size, + // size); // Check that the start address is alligned to 8 bytes if (start & 0x07) != 0 { @@ -262,6 +331,177 @@ impl Mem { } } + /* + Possible alignment situations: + - Full aligned = address is aligned to 8 bytes (last 3 bits are zero) and width is 8 + - Single not aligned = not full aligned, and the data fits into one aligned slice of 8 bytes + - Double not aligned = not full aligned, and the data needs 2 aligned slices of 8 bytes + + Data required for each situation: + - full_aligned + RD = value + - full_aligned + WR = value, full_value + - single_not_aligned + RD = value, full_value TODO: We can save the value space, optimization + - single_not_aligned + WR = value, previous_full_value + - double_not_aligned + RD = value, full_values_0, full_values_1 + - double_not_aligned + WR = value, previous_full_values_0, previous_full_values_1 + + read_required() returns read value, and a vector of additional data required to prove it + */ + + /// Read a u64 value from the memory read sections, based on the provided address and width + #[inline(always)] + pub fn read_required(&self, addr: u64, width: u64) -> (u64, Vec) { + // Calculate how aligned this operation is + let addr_req_1 = addr & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the first 8-bytes chunk + let addr_req_2 = (addr + width - 1) & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the second 8-bytes chunk, if needed + let is_full_aligned = ((addr & 0x03) == 0) && (width == 8); + let is_single_not_aligned = !is_full_aligned && (addr_req_1 == addr_req_2); + let is_double_not_aligned = !is_full_aligned && !is_single_not_aligned; + + // First try to read in the write section + if (addr >= self.write_section.start) && (addr <= (self.write_section.end - width)) { + // Calculate the read position + let read_position: usize = (addr - self.write_section.start) as usize; + + // Read the requested data based on the provided width + let value: u64 = match width { + 1 => self.write_section.buffer[read_position] as u64, + 2 => u16::from_le_bytes( + self.write_section.buffer[read_position..read_position + 2].try_into().unwrap(), + ) as u64, + 4 => u32::from_le_bytes( + self.write_section.buffer[read_position..read_position + 4].try_into().unwrap(), + ) as u64, + 8 => u64::from_le_bytes( + self.write_section.buffer[read_position..read_position + 8].try_into().unwrap(), + ), + _ => panic!("Mem::read() invalid width={}", width), + }; + + // If is a single not aligned operation, return the aligned address value + if is_single_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= self.write_section.start); + let read_position_req: usize = (addr_req_1 - self.write_section.start) as usize; + let value_req = u64::from_le_bytes( + self.write_section.buffer[read_position_req..read_position_req + 8] + .try_into() + .unwrap(), + ); + additional_data.push(value_req); + + return (value, additional_data); + } + + // If is a double not aligned operation, return the aligned address value and the next + // one + if is_double_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= self.write_section.start); + let read_position_req_1: usize = (addr_req_1 - self.write_section.start) as usize; + let value_req_1 = u64::from_le_bytes( + self.write_section.buffer[read_position_req_1..read_position_req_1 + 8] + .try_into() + .unwrap(), + ); + additional_data.push(value_req_1); + + assert!(addr_req_2 >= self.write_section.start); + let read_position_req_2: usize = (addr_req_2 - self.write_section.start) as usize; + let value_req_2 = u64::from_le_bytes( + self.write_section.buffer[read_position_req_2..read_position_req_2 + 8] + .try_into() + .unwrap(), + ); + additional_data.push(value_req_2); + + return (value, additional_data); + } + + //println!("Mem::read() addr={:x} width={} value={:x}={}", addr, width, value, value); + return (value, Vec::new()); + } + + // Search for the section that contains the address using binary search (dicothomic search) + let section = if let Ok(section) = self.read_sections.binary_search_by(|section| { + if addr < section.start { + std::cmp::Ordering::Greater + } else if (addr + width) > section.end { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Equal + } + }) { + &self.read_sections[section] + } else { + println!("sections: {:?}", self.read_sections); + panic!("Mem::read() section not found for addr: {} with width: {}", addr, width); + }; + + // Calculate the read position + let read_position: usize = (addr - section.start) as usize; + + // Read the requested data based on the provided width + let value: u64 = match width { + 1 => section.buffer[read_position] as u64, + 2 => u16::from_le_bytes( + section.buffer[read_position..read_position + 2].try_into().unwrap(), + ) as u64, + 4 => u32::from_le_bytes( + section.buffer[read_position..read_position + 4].try_into().unwrap(), + ) as u64, + 8 => u64::from_le_bytes( + section.buffer[read_position..read_position + 8].try_into().unwrap(), + ), + _ => panic!( + "Mem::read() invalid addr:0x{:X} read_position:{} width:{}", + addr, read_position, width + ), + }; + + // If is a single not aligned operation, return the aligned address value + if is_single_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= section.start); + let read_position_req: usize = (addr_req_1 - section.start) as usize; + let value_req = u64::from_le_bytes( + section.buffer[read_position_req..read_position_req + 8].try_into().unwrap(), + ); + additional_data.push(value_req); + + return (value, additional_data); + } + + // If is a double not aligned operation, return the aligned address value and the next + // one + if is_double_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= section.start); + let read_position_req_1: usize = (addr_req_1 - section.start) as usize; + let value_req_1 = u64::from_le_bytes( + section.buffer[read_position_req_1..read_position_req_1 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_1); + + assert!(addr_req_2 >= section.start); + let read_position_req_2: usize = (addr_req_2 - section.start) as usize; + let value_req_2 = u64::from_le_bytes( + section.buffer[read_position_req_2..read_position_req_2 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_2); + + return (value, additional_data); + } + + //println!("Mem::read() addr={:x} width={} value={:x}={}", addr, width, value, value); + + (value, Vec::new()) + } + /// Write a u64 value to the memory write section, based on the provided address and width #[inline(always)] pub fn write(&mut self, addr: u64, val: u64, width: u64) { @@ -280,8 +520,24 @@ impl Mem { //println!("Mem::write() addr={:x}={} width={} value={:x}={}", addr, addr, width, val, // val); - // Get a reference to the write section - let section = &mut self.write_section; + // Search for the section that contains the address using binary search (dicothomic search) + let section = if let Ok(section) = self.read_sections.binary_search_by(|section| { + if addr < section.start { + std::cmp::Ordering::Greater + } else if addr > (section.end - width) { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Equal + } + }) { + &mut self.read_sections[section] + } else { + /*panic!( + "Mem::write_silent() section not found for addr={:x}={} with width: {}", + addr, addr, width + );*/ + &mut self.write_section + }; // Check that the address and width fall into this section address range if (addr < section.start) || ((addr + width) > section.end) { @@ -304,6 +560,110 @@ impl Mem { 8 => section.buffer[write_position..write_position + 8] .copy_from_slice(&val.to_le_bytes()), _ => panic!("Mem::write_silent() invalid width={}", width), + }; + } + + /// Write a u64 value to the memory write section, based on the provided address and width + #[inline(always)] + pub fn write_silent_required(&mut self, addr: u64, val: u64, width: u64) -> Vec { + //println!("Mem::write() addr={:x}={} width={} value={:x}={}", addr, addr, width, val, + // val); + + // Search for the section that contains the address using binary search (dicothomic search) + let section = if let Ok(section) = self.read_sections.binary_search_by(|section| { + if addr < section.start { + std::cmp::Ordering::Greater + } else if addr > (section.end - width) { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Equal + } + }) { + &mut self.read_sections[section] + } else { + /*panic!( + "Mem::write_silent() section not found for addr={:x}={} with width: {}", + addr, addr, width + );*/ + &mut self.write_section + }; + + // Check that the address and width fall into this section address range + if (addr < section.start) || ((addr + width) > section.end) { + panic!( + "Mem::write_silent() invalid addr={}={:x} write section start={:x} end={:x}", + addr, addr, section.start, section.end + ); + } + + // Calculate how aligned this operation is + let addr_req_1 = addr & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the first 8-bytes chunk + let addr_req_2 = (addr + width - 1) & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the second 8-bytes chunk, if needed + let is_full_aligned = ((addr & 0x03) == 0) && (width == 8); + let is_single_not_aligned = !is_full_aligned && (addr_req_1 == addr_req_2); + let is_double_not_aligned = !is_full_aligned && !is_single_not_aligned; + + // Declare an empty vector + let mut additional_data: Vec = Vec::new(); + + // If is a single not aligned operation, return the aligned address value + if is_single_not_aligned { + assert!( + addr_req_1 >= section.start, + "addr_req_1: 0x{:X} 0x{:X}]", + addr_req_1, + section.start + ); + let read_position_req: usize = (addr_req_1 - section.start) as usize; + let value_req = u64::from_le_bytes( + section.buffer[read_position_req..read_position_req + 8].try_into().unwrap(), + ); + additional_data.push(value_req); } + + // If is a double not aligned operation, return the aligned address value and the next + // one + if is_double_not_aligned { + assert!( + addr_req_1 >= section.start, + "addr_req_1(d): 0x{:X} 0x{:X}]", + addr_req_1, + section.start + ); + let read_position_req_1: usize = (addr_req_1 - section.start) as usize; + let value_req_1 = u64::from_le_bytes( + section.buffer[read_position_req_1..read_position_req_1 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_1); + + assert!( + addr_req_2 >= section.start, + "addr_req_2(d): 0x{:X} 0x{:X}]", + addr_req_2, + section.start + ); + let read_position_req_2: usize = (addr_req_2 - section.start) as usize; + let value_req_2 = u64::from_le_bytes( + section.buffer[read_position_req_2..read_position_req_2 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_2); + } + + // Calculate the write position + let write_position: usize = (addr - section.start) as usize; + + // Write the value based on the provided width + match width { + 1 => section.buffer[write_position] = val as u8, + 2 => section.buffer[write_position..write_position + 2] + .copy_from_slice(&(val as u16).to_le_bytes()), + 4 => section.buffer[write_position..write_position + 4] + .copy_from_slice(&(val as u32).to_le_bytes()), + 8 => section.buffer[write_position..write_position + 8] + .copy_from_slice(&val.to_le_bytes()), + _ => panic!("Mem::write_silent() invalid width={}", width), + } + + additional_data } } diff --git a/core/src/riscv2zisk_context.rs b/core/src/riscv2zisk_context.rs index fd0ab47a..d9a12b4d 100644 --- a/core/src/riscv2zisk_context.rs +++ b/core/src/riscv2zisk_context.rs @@ -1308,8 +1308,13 @@ pub fn add_zisk_code(rom: &mut ZiskRom, addr: u64, data: &[u8]) { /// /// The initial data is copied in chunks of 8 bytes for efficiency, until less than 8 bytes are left /// to copy. The remaining bytes are copied in additional chunks of 4, 2 and 1 byte, if required. -pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { - //print!("add_zisk_init_data() addr={}\n", addr); +pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8], force_aligned: bool) { + /*let mut s = String::new(); + for i in 0..min(50, data.len()) { + s += &format!("{:02x}", data[i]); + } + print!("add_zisk_init_data() addr={:x} len={} data={}...\n", addr, data.len(), s);*/ + let mut o = addr; // Read 64-bit input data chunks and store them in rom @@ -1330,6 +1335,29 @@ pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { o += 8; } + // TODO: review if necessary + let bytes = addr + data.len() as u64 - o; + // If force_aligned is active always store aligned + if force_aligned && bytes > 0 { + let mut v: u64 = 0; + let from = (o - addr + bytes - 1) as usize; + for i in 0..bytes { + v = v * 256 + data[from - i as usize] as u64; + } + let mut zib = ZiskInstBuilder::new(rom.next_init_inst_addr); + zib.src_a("imm", o, false); + zib.src_b("imm", v, false); + zib.op("copyb").unwrap(); + zib.ind_width(8); + zib.store("ind", 0, false, false); + zib.j(4, 4); + zib.verbose(&format!("Init Data {:08x}: {:04x}", o, v)); + zib.build(); + rom.insts.insert(rom.next_init_inst_addr, zib); + rom.next_init_inst_addr += 4; + o += bytes; + } + // Read remaining 32-bit input data chunk, if any, and store them in rom if addr + data.len() as u64 - o >= 4 { let v = u32::from_le_bytes(data[o as usize..o as usize + 4].try_into().unwrap()); @@ -1366,7 +1394,7 @@ pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { // Read remaining 8-bit input data chunk, if any, and store them in rom if addr + data.len() as u64 - o >= 1 { - let v = data[o as usize]; + let v = data[(o - addr) as usize]; let mut zib = ZiskInstBuilder::new(rom.next_init_inst_addr); zib.src_a("imm", o, false); zib.src_b("imm", v as u64, false); @@ -1380,7 +1408,21 @@ pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { rom.next_init_inst_addr += 4; o += 1; } - + /* + if force_aligned { + let mut zib = ZiskInstBuilder::new(rom.next_init_inst_addr); + zib.src_a("imm", o, false); + zib.src_b("imm", 0, false); + zib.op("copyb").unwrap(); + zib.ind_width(8); + zib.store("ind", 0, false, false); + zib.j(4, 4); + zib.verbose(&format!("Init Data {:08x}: {:04x}", o, 0)); + zib.build(); + rom.insts.insert(rom.next_init_inst_addr, zib); + rom.next_init_inst_addr += 4; + } + */ // Check resulting length if o != addr + data.len() as u64 { panic!("add_zisk_init_data() invalid length o={} addr={} data.len={}", o, addr, data.len()); diff --git a/core/src/zisk_ops.rs b/core/src/zisk_ops.rs index 71efe57f..c9e23c10 100644 --- a/core/src/zisk_ops.rs +++ b/core/src/zisk_ops.rs @@ -284,7 +284,7 @@ define_ops! { (MaxuW, "maxu_w", Binary, 77, 0x24, opc_maxu_w, op_maxu_w), (MaxW, "max_w", Binary, 77, 0x25, opc_max_w, op_max_w), (Keccak, "keccak", Keccak, 77, 0xf1, opc_keccak, op_keccak), - (PubOut, "pubout", PubOut, 77, 0x30, opc_pubout, op_pubout), // TODO: New type + (PubOut, "pubout", PubOut, 77, 0x30, opc_pubout, op_pubout), } /* INTERNAL operations */ diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index d82644da..81078f5b 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -1,14 +1,14 @@ //! Data required to prove the different Zisk operations -use std::collections::HashMap; +use std::{collections::HashMap, fmt}; -/// Required data to make an operation. +/// Required data to make an operation. /// /// Stores the minimum information to reproduce an operation execution: /// * The opcode and the a and b registers values (regardless of their sources) /// * The step is also stored to keep track of the program execution point /// -/// This data is generated during the first emulation execution. +/// This data is generated during the first emulation execution. /// This data is required by the main state machine executor to generate the witness computation. #[derive(Clone)] pub struct ZiskRequiredOperation { @@ -20,12 +20,52 @@ pub struct ZiskRequiredOperation { /// Stores the minimum information to generate the memory state machine witness computation. #[derive(Clone)] -pub struct ZiskRequiredMemory { - pub step: u64, - pub is_write: bool, - pub address: u64, - pub width: u64, - pub value: u64, +pub enum ZiskRequiredMemory { + Basic { step: u64, value: u64, address: u32, is_write: bool, width: u8, step_offset: u8 }, + Extended { values: [u64; 2], address: u32 }, +} + +impl fmt::Debug for ZiskRequiredMemory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ZiskRequiredMemory::Basic { step, value, address, is_write, width, step_offset: _ } => { + let label = if *is_write { "WR" } else { "RD" }; + write!( + f, + "{0} addr:{1:#08X}({1}) offset:{5} width:{2} value:{3:#016X}({3}) step:{4}", + label, + address, + width, + value, + step, + address & 0x07 + ) + } + ZiskRequiredMemory::Extended { values, address } => { + write!( + f, + "addr:{1:#08X}({0}) value[1]:{1} value[2]:{2}", + address, values[0], values[1], + ) + } + } + } +} + +impl ZiskRequiredMemory { + pub fn get_address(&self) -> u32 { + match self { + ZiskRequiredMemory::Basic { + step: _, + value: _, + address, + is_write: _, + width: _, + step_offset: _, + } => *address, + ZiskRequiredMemory::Extended { values: _, address } => *address, + } + } } /// Data required to get some operations proven by the secondary state machine @@ -37,9 +77,9 @@ pub struct ZiskRequired { pub memory: Vec, } -/// Histogram of the program counter values used during the program execution. +/// Histogram of the program counter values used during the program execution. /// -/// Each pc value has a u64 counter, associated to it via a hash map. +/// Each pc value has a u64 counter, associated to it via a hash map. /// The counter is increased every time the corresponding instruction is executed. #[derive(Clone, Default)] pub struct ZiskPcHistogram { diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index a5f78a6a..6ec7fa0f 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -9,9 +9,9 @@ use riscv::RiscVRegisters; // #[cfg(feature = "sp")] // use zisk_core::SRC_SP; use zisk_core::{ - InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredOperation, ZiskRom, - OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, SRC_STEP, STORE_IND, STORE_MEM, - STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, + InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, + ZiskRequiredOperation, ZiskRom, OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, + SRC_STEP, STORE_IND, STORE_MEM, STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, }; /// ZisK emulator structure, containing the ZisK rom, the list of ZisK operations, and the @@ -29,8 +29,8 @@ pub struct Emu<'a> { /// - run -> step -> source_a, source_b, store_c (full functionality, called by main state machine, /// calls callback with trace) /// - run -> run_fast -> step_fast -> source_a, source_b, store_c (maximum speed, for benchmarking) -/// - run_slice -> step_slice -> source_a_slice, source_b_slice, store_c_slice (generates full trace -/// and required input data for secondary state machines) +/// - run_slice -> step_slice -> source_a_slice, source_b_slice (generates full trace and required +/// input data for secondary state machines) impl<'a> Emu<'a> { pub fn new(rom: &ZiskRom) -> Emu { Emu { rom, ctx: EmuContext::default() } @@ -92,6 +92,62 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'a' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_a_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { + match instruction.a_src { + SRC_C => self.ctx.inst_ctx.a = self.ctx.inst_ctx.c, + SRC_MEM => { + // Build the memory address + let mut addr = instruction.a_offset_imm0; + if instruction.a_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + + // Call read_required to get both the read value and the additional data (aligned + // read values required to construct the requested read value, if not aligned) + let additional_data: Vec; + (self.ctx.inst_ctx.a, additional_data) = + self.ctx.inst_ctx.mem.read_required(addr, 8); + + // Store the read value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 0, + is_write: false, + address: addr as u32, + width: 8, + value: self.ctx.inst_ctx.a, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } + } + SRC_IMM => { + self.ctx.inst_ctx.a = instruction.a_offset_imm0 | (instruction.a_use_sp_imm1 << 32) + } + SRC_STEP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.step, + // #[cfg(feature = "sp")] + // SRC_SP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.sp, + _ => panic!( + "Emu::source_a() Invalid a_src={} pc={}", + instruction.a_src, self.ctx.inst_ctx.pc + ), + } + } + /// Calculate the 'b' register value based on the source specified by the current instruction #[inline(always)] pub fn source_b(&mut self, instruction: &ZiskInst) { @@ -128,6 +184,94 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'b' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_b_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { + match instruction.b_src { + SRC_C => self.ctx.inst_ctx.b = self.ctx.inst_ctx.c, + SRC_MEM => { + // Build the memory address + let mut addr = instruction.b_offset_imm0; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + + // Call read_required to get both the read value and the additional data (aligned + // read values required to construct the requested read value, if not aligned) + let additional_data: Vec; + (self.ctx.inst_ctx.b, additional_data) = + self.ctx.inst_ctx.mem.read_required(addr, 8); + + // Store the read value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 1, + is_write: false, + address: addr as u32, + width: 8, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } + } + SRC_IMM => { + self.ctx.inst_ctx.b = instruction.b_offset_imm0 | (instruction.b_use_sp_imm1 << 32) + } + SRC_IND => { + // Build the memory address + let mut addr = + (self.ctx.inst_ctx.a as i64 + instruction.b_offset_imm0 as i64) as u64; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + + // Call read_required to get both the read value and the additional data (aligned + // read values required to construct the requested read value, if not aligned) + let additional_data: Vec; + (self.ctx.inst_ctx.b, additional_data) = + self.ctx.inst_ctx.mem.read_required(addr, instruction.ind_width); + + // Store the read value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 1, + is_write: false, + address: addr as u32, + width: instruction.ind_width as u8, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } + } + _ => panic!( + "Emu::source_b() Invalid b_src={} pc={}", + instruction.b_src, self.ctx.inst_ctx.pc + ), + } + } + /// Store the 'c' register value based on the storage specified by the current instruction #[inline(always)] pub fn store_c(&mut self, instruction: &ZiskInst) { @@ -171,45 +315,107 @@ impl<'a> Emu<'a> { } } - /// Store the 'c' register value based on the storage specified by the current instruction and - /// log memory access if required + /// Store the 'c' register value based on the storage specified by the current instruction #[inline(always)] - pub fn store_c_slice(&mut self, instruction: &ZiskInst) { + pub fn store_c_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { match instruction.store { STORE_NONE => {} STORE_MEM => { + // Calculate the value to write let val: i64 = if instruction.store_ra { self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 } else { self.ctx.inst_ctx.c as i64 }; + + // Build the memory address let mut addr: i64 = instruction.store_offset; if instruction.store_use_sp { addr += self.ctx.inst_ctx.sp as i64; } - self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, 8); + + // Call write_silent_required to get the additional data (aligned read values + // required to construct the new written data, if not aligned) + let additional_data: Vec = + self.ctx.inst_ctx.mem.write_silent_required(addr as u64, val as u64, 8); + + // Store the written value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 2, + is_write: true, + address: addr as u32, + width: 8, + value: val as u64, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } } STORE_IND => { + // Calculate the value to write let val: i64 = if instruction.store_ra { self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 } else { self.ctx.inst_ctx.c as i64 }; + + // Build the memory address let mut addr = instruction.store_offset; if instruction.store_use_sp { addr += self.ctx.inst_ctx.sp as i64; } addr += self.ctx.inst_ctx.a as i64; - self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, instruction.ind_width); + + // Call write_silent_required to get the additional data (aligned read values + // required to construct the new written data, if not aligned) + let additional_data: Vec = self.ctx.inst_ctx.mem.write_silent_required( + addr as u64, + val as u64, + instruction.ind_width, + ); + + // Store the written value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 2, + is_write: true, + address: addr as u32, + width: instruction.ind_width as u8, + value: val as u64, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } } _ => panic!( - "Emu::store_c_slice() Invalid store={} pc={}", + "Emu::store_c() Invalid store={} pc={}", instruction.store, self.ctx.inst_ctx.pc ), } } - // Set SP, if specified by the current instruction + /// Set SP, if specified by the current instruction // #[cfg(feature = "sp")] // #[inline(always)] // pub fn set_sp(&mut self, instruction: &ZiskInst) { @@ -449,6 +655,22 @@ impl<'a> Emu<'a> { (emu_traces, emu_segments) } + pub fn par_run_memory(&mut self, inputs: Vec) -> Vec { + // Context, where the state of the execution is stored and modified at every execution step + self.ctx = self.create_emu_context(inputs); + + // Init pc to the rom entry address + self.ctx.trace.start_state.pc = ROM_ENTRY; + + let mut emu_mem = Vec::new(); + + while !self.ctx.inst_ctx.end { + self.par_step_memory::(&mut emu_mem); + } + + emu_mem + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] @@ -456,8 +678,13 @@ impl<'a> Emu<'a> { let pc = self.ctx.inst_ctx.pc; let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); - //println!("Emu::step() executing step={} pc={:x} inst={}", ctx.step, ctx.pc, - // inst.i.to_string()); println!("Emu::step() step={} pc={}", ctx.step, ctx.pc); + /*println!( + "Emu::step() executing step={} pc={:x} inst={}", + self.ctx.inst_ctx.step, + self.ctx.inst_ctx.pc, + instruction.to_text() + );*/ + //println!("Emu::step() step={} pc={}", ctx.step, ctx.pc); // Build the 'a' register value based on the source specified by the current instruction self.source_a(instruction); @@ -622,6 +849,56 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.step += 1; } + /// Performs one single step of the emulation + #[inline(always)] + #[allow(unused_variables)] + pub fn par_step_memory(&mut self, emu_mem: &mut Vec) { + //let last_pc = self.ctx.inst_ctx.pc; + //let last_c = self.ctx.inst_ctx.c; + + let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); + + // println!( + // "#### step={} pc={} op={}={} a={} b={} c={} flag={} inst={}", + // self.ctx.inst_ctx.step, + // self.ctx.inst_ctx.pc, + // instruction.op, + // instruction.op_str, + // self.ctx.inst_ctx.a, + // self.ctx.inst_ctx.b, + // self.ctx.inst_ctx.c, + // self.ctx.inst_ctx.flag, + // instruction.to_text() + // ); + // self.print_regs(); + // println!(); + + // Build the 'a' register value based on the source specified by the current instruction + self.source_a_memory(instruction, emu_mem); + + // Build the 'b' register value based on the source specified by the current instruction + self.source_b_memory(instruction, emu_mem); + + // Call the operation + (instruction.func)(&mut self.ctx.inst_ctx); + + // Store the 'c' register value based on the storage specified by the current instruction + self.store_c_memory(instruction, emu_mem); + + // Set SP, if specified by the current instruction + // #[cfg(feature = "sp")] + // self.set_sp(instruction); + + // Set PC, based on current PC, current flag and current instruction + self.set_pc(instruction); + + // If this is the last instruction, stop executing + self.ctx.inst_ctx.end = instruction.end; + + // Increment step counter + self.ctx.inst_ctx.step += 1; + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] @@ -742,7 +1019,7 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.a = trace_step.a; self.ctx.inst_ctx.b = trace_step.b; (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + // No need to store c // #[cfg(feature = "sp")] // self.set_sp(instruction); self.set_pc(instruction); @@ -788,7 +1065,7 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.a = trace_step.a; self.ctx.inst_ctx.b = trace_step.b; (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + // No need to store c // #[cfg(feature = "sp")] // self.set_sp(instruction); self.set_pc(instruction); diff --git a/emulator/src/emu_context.rs b/emulator/src/emu_context.rs index 811baa6c..6ed4d581 100644 --- a/emulator/src/emu_context.rs +++ b/emulator/src/emu_context.rs @@ -64,6 +64,28 @@ impl EmuContext { impl Default for EmuContext { fn default() -> Self { - Self::new(Vec::new()) + EmuContext { + inst_ctx: InstContext { + mem: Mem::new(), + a: 0, + b: 0, + c: 0, + flag: false, + sp: 0, + pc: ROM_ENTRY, + step: 0, + end: false, + }, + tracerv: Vec::new(), + tracerv_step: 0, + tracerv_current_regs: [0; 32], + trace_pc: 0, + trace: EmuTrace::default(), + do_callback: false, + callback_steps: 0, + last_callback_step: 0, + do_stats: false, + stats: Stats::default(), + } } } diff --git a/emulator/src/emu_trace.rs b/emulator/src/emu_trace.rs index 7baff978..75b4b138 100644 --- a/emulator/src/emu_trace.rs +++ b/emulator/src/emu_trace.rs @@ -13,9 +13,10 @@ pub struct EmuTraceStart { pub step: u64, } -/// Trace data at every step. -/// Only the values of registers a and b are required. -/// The current value of pc evolves starting at the start pc value, as we execute the ROM. +/// Trace data at every step. +/// +/// Only the values of registers a and b are required. +/// The current value of pc evolves starting at the start pc value, as we execute the ROM. /// The value of c and flag can be obtained by executing the ROM instruction corresponding to the /// current value of pc and taking a and b as the input. #[derive(Default, Debug, Clone)] @@ -26,12 +27,13 @@ pub struct EmuTraceStep { pub b: u64, } -/// Trace data at the end of the program execution, including only the `end` flag. -/// If the `end` flag is true, the program executed completely. +/// Trace data at the end of the program execution, including only the `end` flag. +/// +/// If the `end` flag is true, the program executed completely. /// This does not mean that the program ended successfully; it could have found an error condition -/// due to, for example, invalid input data, and then jump directly to the end of the ROM. +/// due to, for example, invalid input data, and then jump directly to the end of the ROM. /// In this error situation, the output data should reveal the success or fail of the completed -/// execution. +/// execution. /// These are the possible combinations: /// * end = false --> program did not complete, e.g. the emulator run out of steps (you can /// configure more steps) diff --git a/emulator/src/emulator.rs b/emulator/src/emulator.rs index 36d3bceb..49d7b8d2 100644 --- a/emulator/src/emulator.rs +++ b/emulator/src/emulator.rs @@ -28,8 +28,8 @@ use std::{ }; use sysinfo::System; use zisk_core::{ - Riscv2zisk, ZiskOperationType, ZiskPcHistogram, ZiskRequiredOperation, ZiskRom, - ZISK_OPERATION_TYPE_VARIANTS, + Riscv2zisk, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, ZiskRequiredOperation, + ZiskRom, ZISK_OPERATION_TYPE_VARIANTS, }; pub trait Emulator { @@ -261,6 +261,22 @@ impl ZiskEmulator { Ok((vec_traces, emu_slices)) } + pub fn par_process_rom_memory( + rom: &ZiskRom, + inputs: &[u8], + ) -> Result, ZiskEmulatorErr> { + let mut emu = Emu::new(rom); + let result = emu.par_run_memory::(inputs.to_owned()); + + if !emu.terminated() { + panic!("Emulation did not complete"); + // TODO! + // return Err(ZiskEmulatorErr::EmulationNoCompleted); + } + + Ok(result) + } + /// Process a Zisk rom with the provided input data, according to the configured options, in /// order to generate a set of required operation data. #[inline] diff --git a/pil/src/lib.rs b/pil/src/lib.rs index aee8bab5..27705cb0 100644 --- a/pil/src/lib.rs +++ b/pil/src/lib.rs @@ -6,8 +6,5 @@ pub use pil_helpers::*; pub const ARITH32_AIR_IDS: &[usize] = &[4, 5]; pub const ARITH64_AIR_IDS: &[usize] = &[6]; pub const ARITH3264_AIR_IDS: &[usize] = &[7]; -pub const MEM_AIRGROUP_ID: usize = 105; -pub const MEM_ALIGN_AIR_IDS: &[usize] = &[1]; -pub const MEM_UNALIGNED_AIR_IDS: &[usize] = &[2, 3]; pub const QUICKOPS_AIRGROUP_ID: usize = 102; pub const QUICKOPS_AIR_IDS: &[usize] = &[10]; diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 9098a62b..919399c6 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -14,21 +14,35 @@ pub const MAIN_AIR_IDS: &[usize] = &[0]; pub const ROM_AIR_IDS: &[usize] = &[1]; -pub const ARITH_AIR_IDS: &[usize] = &[2]; +pub const MEM_AIR_IDS: &[usize] = &[2]; -pub const ARITH_TABLE_AIR_IDS: &[usize] = &[3]; +pub const ROM_DATA_AIR_IDS: &[usize] = &[3]; -pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[4]; +pub const INPUT_DATA_AIR_IDS: &[usize] = &[4]; -pub const BINARY_AIR_IDS: &[usize] = &[5]; +pub const MEM_ALIGN_AIR_IDS: &[usize] = &[5]; -pub const BINARY_TABLE_AIR_IDS: &[usize] = &[6]; +pub const MEM_ALIGN_ROM_AIR_IDS: &[usize] = &[6]; -pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[7]; +pub const ARITH_AIR_IDS: &[usize] = &[7]; -pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[8]; +pub const ARITH_TABLE_AIR_IDS: &[usize] = &[8]; -pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[9]; +pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[9]; + +pub const BINARY_AIR_IDS: &[usize] = &[10]; + +pub const BINARY_TABLE_AIR_IDS: &[usize] = &[11]; + +pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[12]; + +pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[13]; + +pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[14]; + +pub const U_8_AIR_AIR_IDS: &[usize] = &[15]; + +pub const U_16_AIR_AIR_IDS: &[usize] = &[16]; pub struct Pilout; @@ -39,7 +53,12 @@ impl Pilout { let air_group = pilout.add_air_group(Some("Zisk")); air_group.add_air(Some("Main"), 2097152); - air_group.add_air(Some("Rom"), 1048576); + air_group.add_air(Some("Rom"), 4194304); + air_group.add_air(Some("Mem"), 2097152); + air_group.add_air(Some("RomData"), 2097152); + air_group.add_air(Some("InputData"), 2097152); + air_group.add_air(Some("MemAlign"), 2097152); + air_group.add_air(Some("MemAlignRom"), 256); air_group.add_air(Some("Arith"), 2097152); air_group.add_air(Some("ArithTable"), 128); air_group.add_air(Some("ArithRangeTable"), 4194304); @@ -48,6 +67,8 @@ impl Pilout { air_group.add_air(Some("BinaryExtension"), 2097152); air_group.add_air(Some("BinaryExtensionTable"), 4194304); air_group.add_air(Some("SpecifiedRanges"), 16777216); + air_group.add_air(Some("U8Air"), 256); + air_group.add_air(Some("U16Air"), 65536); pilout } diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 47541a74..6137366d 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -11,6 +11,26 @@ trace!(RomRow, RomTrace { multiplicity: F, }); +trace!(MemRow, MemTrace { + addr: F, step: F, sel: F, addr_changes: F, value: [F; 2], wr: F, increment: F, +}); + +trace!(RomDataRow, RomDataTrace { + addr: F, step: F, sel: F, addr_changes: F, value: [F; 2], +}); + +trace!(InputDataRow, InputDataTrace { + addr: F, step: F, sel: F, addr_changes: F, value_word: [F; 4], +}); + +trace!(MemAlignRow, MemAlignTrace { + addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, delta_addr: F, sel_prove: F, value: [F; 2], +}); + +trace!(MemAlignRomRow, MemAlignRomTrace { + multiplicity: F, +}); + trace!(ArithRow, ArithTrace { carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, sext: F, m32: F, div: F, fab: F, na_fb: F, nb_fa: F, debug_main_step: F, main_div: F, main_mul: F, signed: F, div_by_zero: F, div_overflow: F, inv_sum_all_bs: F, op: F, bus_res1: F, multiplicity: F, range_ab: F, range_cd: F, }); @@ -40,7 +60,15 @@ trace!(BinaryExtensionTableRow, BinaryExtensionTableTrace { }); trace!(SpecifiedRangesRow, SpecifiedRangesTrace { - mul: [F; 1], + mul: [F; 2], +}); + +trace!(U8AirRow, U8AirTrace { + mul: F, +}); + +trace!(U16AirRow, U16AirTrace { + mul: F, }); trace!(RomRomRow, RomRomTrace { diff --git a/pil/zisk.pil b/pil/zisk.pil index 6dc5052c..bce36ffa 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -1,17 +1,27 @@ require "rom/pil/rom.pil" require "main/pil/main.pil" +require "mem/pil/mem.pil" +require "mem/pil/mem_align.pil" +require "mem/pil/mem_align_rom.pil" require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" require "binary/pil/binary_extension.pil" require "binary/pil/binary_extension_table.pil" require "arith/pil/arith.pil" -// require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; + airgroup Zisk { Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); - Rom(N: 2**20); - // Mem(N: 2**21, RC: 2); + Rom(N: 2**22); + + Mem(N: 2**21, RC: 2, base_address: 0xA000_0000); + Mem(N: 2**21, RC: 2, base_address: 0x8000_0000, immutable: 1) alias RomData; + Mem(N: 2**21, RC: 2, base_address: 0x9000_0000, free_input_mem: 1) alias InputData; + MemAlign(N: 2**21); + MemAlignRom(disable_fixed: 0); + // InputData(N: 2**21, RC: 2); + Arith(N: 2**21, operation_bus_id: OPERATION_BUS_ID); ArithTable(); ArithRangeTable(); diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 6788f7de..e8bd35d7 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -225,9 +225,9 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) { RANGE_CD[index] = range_cd; if (generate_table) { - println(`OP:${opcode} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} m32:${m32} div:${div}`, - `div_by_zero:${div_by_zero} div_overflow:${div_overflow} sa:${sa} sb:${sb} main_mul:${main_mul}`, - `main_div:${main_div} signed:${signed} range_ab:${range_ab} range_cd:${range_cd} index:${(opcode - 0xb0) * 128 + icase} icase:${icase}`); + // println(`OP:${opcode} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} m32:${m32} div:${div}`, + // `div_by_zero:${div_by_zero} div_overflow:${div_overflow} sa:${sa} sb:${sb} main_mul:${main_mul}`, + // `main_div:${main_div} signed:${signed} range_ab:${range_ab} range_cd:${range_cd} index:${(opcode - 0xb0) * 128 + icase} icase:${icase}`); op2row[(opcode - 0xb0) * 128 + icase] = index; code = code + `[${opcode}, ${flags}, ${range_ab}, ${range_cd}],`; diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 99d71a65..afa3bd3c 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -78,7 +78,7 @@ impl ArithFullSM { let num_rows = air.num_rows(); timer_start_trace!(ARITH_TRACE); info!( - "{}: ··· Creating Arith instance KKKKK [{} / {} rows filled {:.2}%]", + "{}: ··· Creating Arith instance [{} / {} rows filled {:.2}%]", Self::MY_NAME, input.len(), num_rows, @@ -259,7 +259,6 @@ impl ArithFullSM { } timer_stop_and_log_trace!(ARITH_PADDING); timer_start_trace!(ARITH_TABLE); - info!("{}: ··· calling arit_table_sm", Self::MY_NAME); self.arith_table_sm.process_slice(&table_inputs); timer_stop_and_log_trace!(ARITH_TABLE); timer_start_trace!(ARITH_RANGE_TABLE); diff --git a/state-machines/arith/src/arith_table.rs b/state-machines/arith/src/arith_table.rs index dc535754..79fdbb2a 100644 --- a/state-machines/arith/src/arith_table.rs +++ b/state-machines/arith/src/arith_table.rs @@ -58,9 +58,7 @@ impl ArithTableSM { // Create the trace vector let mut _multiplicity = self.multiplicity.lock().unwrap(); - info!("{}: ··· process multiplicity", Self::MY_NAME); for (row, value) in inputs { - info!("{}: ··· Processing row {} with value {}", Self::MY_NAME, row, value); _multiplicity[row] += value; } self.used.store(true, Ordering::Relaxed); diff --git a/state-machines/arith/src/arith_table_helpers.rs b/state-machines/arith/src/arith_table_helpers.rs index ba557e07..67f2a730 100644 --- a/state-machines/arith/src/arith_table_helpers.rs +++ b/state-machines/arith/src/arith_table_helpers.rs @@ -25,9 +25,9 @@ impl ArithTableHelpers { sext as u64 * 16 + div_by_zero as u64 * 32 + div_overflow as u64 * 64; - assert!(index < ARITH_TABLE_ROWS.len() as u64); + debug_assert!(index < ARITH_TABLE_ROWS.len() as u64); let row = ARITH_TABLE_ROWS[index as usize]; - assert!( + debug_assert!( row < 255, "INVALID ROW row:{} op:0x{:x} na:{} nb:{} np:{} nr:{} sext:{} div_by_zero:{} div_overflow:{} index:{}", row, diff --git a/state-machines/binary/pil/binary_extension_table.pil b/state-machines/binary/pil/binary_extension_table.pil index 35e0ad35..f521b550 100644 --- a/state-machines/binary/pil/binary_extension_table.pil +++ b/state-machines/binary/pil/binary_extension_table.pil @@ -4,9 +4,9 @@ require "std_lookup.pil" // Operations Table: // Running Total // SLL (OP:0x31) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^19 -// SRL (OP:0x32) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 +// SRL (OP:0x32) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 // SRA (OP:0x33) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 + 2^19 -// SLL_W (OP:0x34) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 +// SLL_W (OP:0x34) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 // SRL_W (OP:0x35) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^19 // SRA_W (OP:0x36) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^20 // SE_B (OP:0x37) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 + 2^11 @@ -16,7 +16,7 @@ require "std_lookup.pil" const int BINARY_EXTENSION_TABLE_ID = 124; airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = 0) { - + #pragma memory m1 start const int SE_MASK_32 = 0xFFFFFFFF00000000; @@ -28,15 +28,15 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = const int LS_5_BITS = 0x1F; const int LS_6_BITS = 0x3F; - + col witness multiplicity; if (disable_fixed) { col fixed _K = [0...]; // FORCE ONE TRACE multiplicity * _K === 0; - - println("*** DISABLE_FIXED ***"); + + println("*** DISABLE_FIXED ***"); return; } @@ -58,7 +58,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = // Input B (8 bits) col fixed B = [[0:P2_11..255:P2_11]:6, // SLL, SRL, SRA, SLL_W, SRL_W, SRA_W - 0:(P2_11*3)]...; // SE_B, SE_H, SE_W + 0:(P2_11*3)]...; // SE_B, SE_H, SE_W // Operation is shift (fixed values) col fixed OP_IS_SHIFT = [1:(P2_19*6), // SLL, SRL, SRA, SLL_W, SRL_W, SRA_W @@ -84,12 +84,12 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = const int _a = a << (8*offset); switch (op) { case 0x31: // SLL - _out = _a << (b & LS_6_BITS); + _out = _a << (b & LS_6_BITS); case 0x32: // SRL _out = _a >> (b & LS_6_BITS); - case 0x33: { // SRA + case 0x33: { // SRA const int _b = b & LS_6_BITS; _out = _a >> _b; if (offset == 7) { @@ -110,7 +110,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = _out = _out | SE_MASK_32; } } - + case 0x35: // SRL_W if (offset >= 4) { // last most significant bytes are ignored because it's 32-bit operation @@ -148,7 +148,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } case 0x38: // SE_H - if (offset == 0) { + if (offset == 0) { // fist byte not define the sign extend, but participate of result _out = a; } else if (offset == 1) { diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index aefe09f6..b37c5a22 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -10,7 +10,7 @@ use proofman_common::AirInstance; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; use std::cmp::Ordering as CmpOrdering; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; +use zisk_pil::{BinaryRow, BinaryTrace, BINARY_AIR_IDS, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index ded8972f..4e22f8fe 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -15,7 +15,10 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; +use zisk_pil::{ + BinaryExtensionRow, BinaryExtensionTrace, BINARY_EXTENSION_AIR_IDS, + BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID, +}; const MASK_32: u64 = 0xFFFFFFFF; const MASK_64: u64 = 0xFFFFFFFFFFFFFFFF; diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 740a6e63..6d8d7506 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -80,7 +80,7 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope col witness air.b_imm1; } col witness b_src_ind; - col witness ind_width; // 8 , 4, 2, 1 + col witness ind_width; // 8, 4, 2, 1 // Operations related @@ -113,8 +113,6 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope col witness jmp_offset1, jmp_offset2; // if flag, goto2, else goto 1 col witness m32; - const expr addr_step = STEP * 3; - const expr sel_mem_b; sel_mem_b = b_src_mem + b_src_ind; @@ -136,17 +134,18 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope } // Mem.load - //mem_load(sel: a_src_mem, - // step: addr_step, - // addr: addr0, - // value: a); + mem_load(sel: a_src_mem, + step: STEP, + addr: addr0, + value: a); // Mem.load - //mem_load(sel: sel_mem_b, - // step: addr_step + 1, - // bytes: ind_width, - // addr: addr1, - // value: b); + mem_load(sel: sel_mem_b, + step: STEP, + step_offset: 1, + bytes: b_src_ind * (ind_width - 8) + 8, + addr: addr1, + value: b); const expr store_value[2]; @@ -154,11 +153,12 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope store_value[1] = (1 - store_ra) * c[1]; // Mem.store - //mem_store(sel: store_mem + store_ind, - // step: addr_step + 2, - // bytes: ind_width, - // addr: addr2, - // value: store_value); + mem_store(sel: store_mem + store_ind, + step: STEP, + step_offset: 2, + bytes: store_ind * (ind_width - 8) + 8, + addr: addr2, + value: store_value); // Operation.assume => how organize software col witness __debug_operation_bus_enabled; @@ -241,12 +241,8 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope // const expr bus_main_segment = main_segment - SEGMENT_LAST * (main_segment * main_last_segment - 1 + main_last_segment); - // permutation_proves(MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, ...specific_registers, c[0] * (1 - main_last_segment), c[1] * (1 - main_last_segment)], - // sel: SEGMENT_LAST - SEGMENT_L1, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); - permutation_proves(MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, pc, c[0], c[1], set_pc, jmp_offset1, flag * SEGMENT_LAST * (jmp_offset1 - jmp_offset2) + jmp_offset2], - sel: SEGMENT_LAST, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); - permutation_assumes(MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, pc, c[0], c[1], set_pc, jmp_offset1, flag * SEGMENT_LAST * (jmp_offset1 - jmp_offset2) + jmp_offset2], - sel: SEGMENT_L1, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); + permutation (MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, pc, c[0], c[1], set_pc, jmp_offset1, flag * SEGMENT_LAST * (jmp_offset1 - jmp_offset2) + jmp_offset2], + sel: SEGMENT_LAST - SEGMENT_L1, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); flag * (1 - flag) === 0; diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index a7bcef56..52045db4 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -1,5 +1,6 @@ use log::info; use p3_field::PrimeField; +use sm_mem::MemProxy; use crate::InstanceExtensionCtx; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; @@ -12,7 +13,6 @@ use proofman_common::{AirInstance, ProofCtx}; use proofman::WitnessComponent; use sm_arith::ArithSM; -use sm_mem::MemSM; use zisk_pil::{ ArithTrace, BinaryExtensionTrace, BinaryTrace, MainRow, MainTrace, ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, @@ -28,14 +28,14 @@ pub struct MainSM { /// Witness computation manager wcm: Arc>, + /// Memory state machine + mem_proxy_sm: Arc>, + /// Arithmetic state machine arith_sm: Arc>, /// Binary state machine binary_sm: Arc>, - - /// Memory state machine - mem_sm: Arc, } impl MainSM { @@ -54,16 +54,16 @@ impl MainSM { /// * Arc to the MainSM state machine pub fn new( wcm: Arc>, + mem_proxy_sm: Arc>, arith_sm: Arc>, binary_sm: Arc>, - mem_sm: Arc, ) -> Arc { - let main_sm = Arc::new(Self { wcm: wcm.clone(), arith_sm, binary_sm, mem_sm }); + let main_sm = Arc::new(Self { wcm: wcm.clone(), mem_proxy_sm, arith_sm, binary_sm }); wcm.register_component(main_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MAIN_AIR_IDS)); // For all the secondary state machines, register the main state machine as a predecessor - main_sm.mem_sm.register_predecessor(); + main_sm.mem_proxy_sm.register_predecessor(); main_sm.binary_sm.register_predecessor(); main_sm.arith_sm.register_predecessor(); @@ -153,6 +153,39 @@ impl MainSM { segment_trace.steps[slice_start..slice_end].iter().enumerate() { partial_trace[i] = emu.step_slice_full_trace(emu_trace_step); + // if partial_trace[i].a_src_mem == F::one() { + // println!( + // "A=MEM_OP_RD({}) [{},{}] PC:{}", + // partial_trace[i].a_offset_imm0, + // partial_trace[i].a[0], + // partial_trace[i].a[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == + // F::one() { + // println!( + // "B=MEM_OP_RD({0}) [{1},{2}] PC:{3}", + // partial_trace[i].addr1, + // partial_trace[i].b[0], + // partial_trace[i].b[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == + // F::one() { + // println!( + // "MEM_OP_WR({}) [{}, {}] PC:{}", + // partial_trace[i].store_offset + // + partial_trace[i].store_ind * partial_trace[i].a[0], + // partial_trace[i].store_ra + // * (partial_trace[i].pc + partial_trace[i].jmp_offset2 + // - partial_trace[i].c[0]) + // + partial_trace[i].c[0], + // (F::one() - partial_trace[i].store_ra) * partial_trace[i].c[1], + // partial_trace[i].pc + // ); + // } } // if there are steps in the chunk update last row if slice_end - slice_start > 0 { diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index 3f8ee914..7cdb344d 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -7,14 +7,21 @@ edition = "2021" sm-common = { path = "../common" } zisk-core = { path = "../../core" } zisk-pil = { path = "../../pil" } +num-traits = "0.2" -p3-field = { workspace=true } proofman-common = { workspace = true } proofman-macros = { workspace = true } +proofman-util = { workspace = true } proofman = { workspace = true } +pil-std-lib = { workspace = true } + +p3-field = { workspace=true } log = { workspace = true } rayon = { workspace = true } +num-bigint = { workspace = true } [features] default = [] -no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] \ No newline at end of file +no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] +debug_mem_proxy_engine = [] +debug_mem_align = [] \ No newline at end of file diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 50bd652e..f574f264 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -1,3 +1,42 @@ +/* + Memory Component (Mem) + ====================== + + - Allows to define a memory on a region with size <= MEMORY_MAX_DIFF * mem_bytes (2^24). + - Inside this component the address are mem-byte address, to translate internal adress to + external need to multiply by mem_bytes. + - For executors optimization, external addresses use 32-bits. + - The memory regions must be exclusive, to avoid collisions between different memories. + - The constraints over instances guarantees that the memory access are inside definited region. + - The constraints guarantees that only one cyle for memory region is allowed. + - For non-aligned or for non mem-bytes access, the MemAlign machine was used. + + Parameters: + + - N = number of rows + - id = bus_id used of memory operations + - RC = number of value chunks (2 by default) + - mem_bytes = number of bytes of memory word (8 bytes by default) + - base_address = base byte address when start the memory + - mem_size = size of memory in bytes (0x800_0000 by default) + - immutable = if memory is immutable, first access is a write (by default is mutable) + - free_input_mem = if memory is a free input memory, memory without write, all access are reads + with same value, this value it's stablished by executor. + + Continuations: + + - The memory continuation is used to proves the last row significant values of the current segment, + and the next segment assume these significant values. + - The first assume of memory is generated by global constraint to guarantees only one cycle by + memory region. + - In the last segment, the proves are not generated to avoid generate more than one memory cycle. + - The constraints that refer to the values of the previous row, in the first row, take the value + from the airvalue previous_segment_xxx, which contains the value at the end of the previous segment. + - These previous airvalues are validated throw bus, because assume these values at end of previous + segment. + +*/ + require "std_permutation.pil" require "std_range_check.pil" @@ -6,101 +45,191 @@ const int MEMORY_CONT_ID = 11; const int MEMORY_LOAD_OP = 1; const int MEMORY_STORE_OP = 2; -const int MEMORY_MAX_DIFF = 2**22; +const int MEMORY_MAX_DIFF = 2**24; -const int MAX_MEM_STEP_OFFSET = 3; +const int MAX_MEM_STEP_OFFSET = 2; +const int MAX_MEM_OPS_PER_MAIN_STEP = (MAX_MEM_STEP_OFFSET + 1) * 2; -airtemplate Mem (int N = 2**21, int RC = 2, int id = MEMORY_ID, int MAX_STEP = 2 ** 23, int MEM_BYTES = 8 ) { +airtemplate Mem(const int N = 2**21, const int id = MEMORY_ID, const int RC = 2, const int mem_bytes = 8, const int base_address = 0, const int mem_size = 0x800_0000, int immutable = 0, const int free_input_mem = 0) { col fixed SEGMENT_L1 = [1,0...]; const expr SEGMENT_LAST = SEGMENT_L1'; - airval mem_segment; - airval mem_last_segment; + // in this air the address in a mem-bytes address (internal), when this address is pushed in BUS must be multiplied + // by mem_bytes to get the real address. + + const expr internal_base_address = base_address / mem_bytes; + const expr internal_end_address = (base_address + mem_size - 1) / mem_bytes; + airval segment_id; + airval is_first_segment; + airval is_last_segment; - col witness addr; // n-byte address, real address = addr * MEM_BYTES + is_first_segment * (1 - is_first_segment) === 0; + is_last_segment * (1 - is_last_segment) === 0; + is_first_segment * segment_id === 0; + + col witness addr; // n-byte address, real address = addr * mem_bytes col witness step; - col witness sel, wr; - col witness value[RC]; + col witness sel; col witness addr_changes; - const expr rd = (1 - wr); - sel * (1 - sel) === 0; - wr * (1 - wr) === 0; + if (!free_input_mem) { + col witness air.value[RC]; + } else { + immutable = 1; + col witness air.value_word[RC*2]; + const expr air.value[RC]; + for (int index = 0; index < RC; ++index) { + value[index] = value_word[index*2] + 2**16 * value_word[index*2 + 1]; + + // how value is a free-input, must be checked that it's 32-bit well formed value + range_check(value_word[index*2], 0, 2**16 - 1); + range_check(value_word[index*2+1], 0, 2**16 - 1); + } + } + if (!immutable) { + col witness air.wr; + const expr air.rd = 1 - wr; + wr * (1 - wr) === 0; + } else { + // a free input memory must be read-only, an immutable memory must be write + // on first row of new address (addr_changes = 1) + const expr air.wr = free_input_mem ? 0 : addr_changes; + } // if wr is 1, sel must be 1 (not allowed writes) wr * (1 - sel) === 0; - // all time first line is lost, used for continuations - sel * SEGMENT_L1 === 0; + sel * (1 - sel) === 0; addr_changes * (1 - addr_changes) === 0; + airval previous_segment_value[RC]; + airval previous_segment_step, previous_segment_addr; + + // continuation for next segment, these values used on direct update to air bus, and after + // with constraints force that these values are the same as last row of current segment. + + airval segment_last_value[RC]; + airval segment_last_step, segment_last_addr; + + for (int i = 0; i < RC; i++) { + SEGMENT_LAST * (value[i] - segment_last_value[i]) === 0; + } + + SEGMENT_LAST * (addr - segment_last_addr) === 0; + SEGMENT_LAST * (step - segment_last_step) === 0; + + // add base_address to the columns to avoid collisions between different memories + // for security send is_last_segment to avoid reuse end of last segment as start of new cycle of segments + direct_update_assumes(MEMORY_CONT_ID, + [ + base_address, // identify area of memory + segment_id, // current segment_id + // proves of last segment + previous_segment_addr, + previous_segment_step, + ...previous_segment_value + ]); + + direct_update_proves(MEMORY_CONT_ID, + [ + base_address, // identify area of memory + segment_id + 1, // next segment_id, for last segment + // this value is forced to 0 to match global constraint + segment_last_addr, // last addr of segment + segment_last_step, // last step of segment + ...segment_last_value + ], + sel: (1 - is_last_segment)); + + const int zeros[air.RC]; + for (int i = 0; i < length(zeros); ++i) { + zeros[i] = 0; + } + direct_global_update_proves(MEMORY_CONT_ID, [ base_address, 0, internal_base_address, 0, ...zeros]); + + // for security check that first address has correct value, to avoid add huge quantity of instances to "overflow" prime field. + range_check(colu: previous_segment_addr - internal_base_address + 1, min: 1, max: MEMORY_MAX_DIFF); + + // control final of memory + range_check(colu: internal_end_address - segment_last_addr + 1, min: 1, max: MEMORY_MAX_DIFF); + + // check increment of memory - range_check(sel: (1 - SEGMENT_L1), colu: addr_changes * (addr - 'addr - step + 'step) + step - 'step, min: 1, max: MEMORY_MAX_DIFF); + if (immutable) { + // addresses are incremental, to save range check, increment, etc, address must be consecutive. + const expr air.previous_addr = SEGMENT_L1 * (previous_segment_addr - is_first_segment - 'addr) + 'addr; + const expr delta_addr = addr - previous_addr; + addr_changes * (delta_addr - 1) === 0; + (1 - addr_changes) * (addr - previous_addr) === 0; + } else { + const expr air.previous_addr = SEGMENT_L1 * (previous_segment_addr - 'addr) + 'addr; + const expr delta_addr = addr - previous_addr; - // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 - // setting mem_last_segment = 1 + // on first row of first segment could be the same and address_change = 1 because it's as a new addr + // SEGMENT_L1 * (x + is_first_segment * SEGMENT_L1) === SEGMENT_L1 * (x + is_first_segment) - // if addr_changes == 0 means that addr and previous address are the same - (1 - addr_changes) * ('addr - addr) === 0; + const expr previous_step = SEGMENT_L1 * (previous_segment_step - 'step) + 'step; + const expr delta_step = step - previous_step; - col witness same_value; - (1 - same_value) * (1 - wr) * (1 - addr_changes) === 0; + col witness increment; + increment === addr_changes * (delta_addr - delta_step) + delta_step; - col witness first_addr_access_is_read; - (1 - first_addr_access_is_read) * rd * (1 - addr_changes) === 0; + is_first_segment * SEGMENT_L1 * (1 - addr_changes) === 0; - for (int index = 0; index < length(value); index = index + 1) { - same_value * (value[index] - 'value[index]) === 0; - first_addr_access_is_read * value[index] === 0; + range_check(colu: increment, min: 1, max: MEMORY_MAX_DIFF); } - // CONTINUATIONS - // - // segments: S, S+1 - // - // CASE: last row of segment is read - // - // S[n-1] wr = 0, sel = 1, addr, step, value => BUS.proves(MEM_CONT_ID, S+1, addr, step-1, value) - // S+1[0] wr = 0, sel = 0, addr, step, value => BUS.assumes(MEM_CONT_ID, S, addr, step, value) - // - // CASE: last row of segment is write - // - // S[n-1] wr = 1, sel = 1, addr, step, value => BUS.proves(MEM_CONT_ID, S+1, addr, step-1, value) - // S+1[0] wr = 0, sel = 0, addr, step, value => BUS.assumes(MEM_CONT_ID, S, addr, step, value) - // - // NOTES: from row = 1 all constraints could be reference previous row, without problems - // on row = 0 forced by constraint that sel = 0 => wr = 0. - // on S+1[0].step = S[n-1].step - 1; - // - // FIRST SEGMENT: - // the BUS.proves needed by BUS.assumes of the first segment it's generated by global constraint to avoid - // generate more than one cycle of memory. In this constraint we could force the initial address (to split - // in two memories, one register-memory and other standard-memory). - // - // LAST SEGMENT: - // the last not used rows are filled with last addr and value and sel = 0 and wr = 0 incrementing steps. - // last BUS.proves not it's generated to avoid generate more than one memory cycle. - - // permutation_proves(MEMORY_CONT_ID, [(mem_segment + 1), addr, step, ...value], sel: mem_last_segment * 'SEGMENT_L1); // last row - // permutation_assumes(MEMORY_CONT_ID, [mem_segment, 0, addr, step, ...value], sel: SEGMENT_L1); // first row - - permutation_proves(MEMORY_ID, cols: [wr, addr * MEM_BYTES, step, MEM_BYTES, ...value], sel: sel); -} + (1 - addr_changes) * (addr - previous_addr) === 0; -// TODO: detect non default value but not called, mandatory parameter. -function mem_load(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { - if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 + // setting is_last_segment = 1 + + // if addr_changes == 0 means that addr and previous address are the same + // TODO: + + for (int index = 0; index < length(value); index++) { + const expr previous_value = SEGMENT_L1 * (previous_segment_value[index] - 'value[index]) + 'value[index]; + if (immutable) { + // if address not change value must be equal to previous value + (1 - addr_changes) * (value[index] - previous_value) === 0; + + if (!free_input_mem) { + // if address changes => write, and it must be inserted on bus + addr_changes * (1 - sel) === 0; + } + } else { + // if address not change and it isn't write, value must be equal to previous value + // TODO: boundary constraints + (1 - addr_changes) * (1 - wr) * (value[index] - previous_value) === 0; + + // if address changes, and it isn't a write, value must be 0. + addr_changes * (1 - wr) * value[index] === 0; + } } - // adding one for first continuation - permutation_assumes(id, [MEMORY_LOAD_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step) + step_offset, bytes, ...value], sel:sel); + + // The Memory component is only able to prove aligned memory access, since we force the bus address to be a multiple of mem_bytes + // and the width to be exactly mem_bytes + // Notice, however, that the main can also use widths of 4, 2, 1 and addresses that are not multiples of mem_bytes. + // These are handled with the Memory Align component + + const expr mem_op = wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP; + permutation_proves(MEMORY_ID, cols: [mem_op, addr * mem_bytes, step, mem_bytes, ...value], sel: sel); +} + +function mem_load(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_LOAD_OP, addr, step, step_offset, bytes, value, sel); } -function mem_store(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { +function mem_store(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_STORE_OP, addr, step, step_offset, bytes, value, sel); +} + +private function mem_assumes(int id, int mem_op, expr addr, expr step, expr step_offset, expr bytes, expr value[], expr sel) { if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); + error("step_offset ${step_offset} is greater than max value allowed ${MAX_MEM_STEP_OFFSET}"); } - // adding one for first continuation - permutation_assumes(id, [MEMORY_STORE_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step), bytes, ...value], sel:sel); -} \ No newline at end of file + + // adding 1 at step for first continuation + permutation_assumes(id, [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value], sel: sel); +} diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index e69de29b..8a23ab2a 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -0,0 +1,188 @@ +require "std_permutation.pil" +require "std_lookup.pil" +require "std_range_check.pil" + +// Problem to solve: +// ================= +// We are given an op (rd,wr), an addr, a step and a bytes-width (8,4,2,1) and we should prove that the memory access is correct. +// Note: Either the original addr is not a multiple of 8 or width < 8 to ensure it is a non-aligned access that should be +// handled by this component. + +/* + We will model it as a very specified processor with 8 registers and a very limited instruction set. + + This processor is limited to 4 possible subprograms: + + 1] Read operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+---+---+---+ + |<------ v ------>| + + [R] In the first clock cycle, we perform an aligned read to w + [V] In the second clock cycle, we return the demanded value v from w + + 2] Write operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+---+---+---+ +---+===+===+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+---+ + |<- v ->| + + [R] In the first clock cycle, we perform an aligned read to w + [W] In the second clock cycle, we compute an aligned write of v to w + [V] In the third clock cycle, we restore the demanded value from w + + 3] Read operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned read to w1 + [V] In the second clock cycle, we return the demanded value v from w1 and w2 + [R] In the third clock cycle, we perform an aligned read to w2 + + 4] Write operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned read to w1 + [W] In the second clock cycle, we compute an aligned write of v to w1 + [V] In the third clock cycle, we restore the demanded value from w1 and w2 + [R] In the fourth clock cycle, we perform an aligned read to w2 + [W] In the fiveth clock cycle, we compute an aligned write of v to w2 + + Example: + ========================================================== + (offset = 6, width = 4) + +----+----+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R1] (assume, up_to_down) sel = [1,1,1,1,1,1,0,0] + +----+----+----+----+----+----+----+----+ + ⇓ + +----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W1] (assume, up_to_down) sel = [0,0,0,0,0,0,1,1] + +----+----+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+====+====+ + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | [V] (prove) (shift (offset + width) % 8) sel = [0,0,0,0,0,0,1,0] (*) + +====+====+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W2] (assume, down_to_up) sel = [1,1,0,0,0,0,0,0] + +====+====+----+----+----+----+----+----+ + ⇓ + +----+----+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R2] (assume, down_to_up) sel = [0,0,1,1,1,1,1,1] + +----+----+----+----+----+----+----+----+ + + (*) In this step, we use the selectors to indicate the "scanning" needed to form the bus value: + v_0 = sel[0] * [V1,V0,V7,V6] + sel[1] * [V0,V7,V6,V5] + sel[2] * [V7,V6,V5,V4] + sel[3] * [V6,V5,V4,V3] + v_1 = sel[4] * [V5,V4,V3,V2] + sel[5] * [V4,V3,V2,V1] + sel[6] * [V3,V2,V1,V0] + sel[7] * [V2,V1,V0,V7] + Notice that it is enough with 8 combinations. +*/ + +airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM = 8, const int CHUNK_BITS = 8) { + const int CHUNKS_BY_RC = CHUNK_NUM / RC; + + col witness addr; // CHUNK_NUM-byte address, real address = addr * CHUNK_NUM + col witness offset; // 0..7, position at which the operation starts + col witness width; // 1,2,4,8, width of the operation + col witness wr; // 1 if the operation is a write, 0 otherwise + col witness pc; // line of the program to execute + col witness reset; // 1 at the beginning of the operation (indicating an address reset), 0 otherwise + col witness sel_up_to_down; // 1 if the next value is the current value (e.g. R -> W) + col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) + col witness reg[CHUNK_NUM]; // Register values, 1 byte each + col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise + col witness step; // Memory step + + // 1] Ensure the MemAlign follows the program + + // Registers should be bytes and be shuch that: + // - reg' == reg in transitions R -> V, R -> W, W -> V, + // - 'reg == reg in transitions V <- W, W <- R, + // in any case, sel_up_to_down,sel_down_to_up are 0 in [V] steps. + for (int i = 0; i < CHUNK_NUM; i++) { + range_check(reg[i], 0, 2**CHUNK_BITS-1); + + (reg[i]' - reg[i]) * sel[i] * sel_up_to_down === 0; + ('reg[i] - reg[i]) * sel[i] * sel_down_to_up === 0; + } + + col fixed L1 = [1,0...]; + L1 * pc === 0; // The program should start at the first line + + // We compress selectors, so we should ensure they are binary + for (int i = 0; i < CHUNK_NUM; i++) { + sel[i] * (1 - sel[i]) === 0; + } + wr * (1 - wr) === 0; + reset * (1 - reset) === 0; + sel_up_to_down * (1 - sel_up_to_down) === 0; + sel_down_to_up * (1 - sel_down_to_up) === 0; + + expr flags = 0; + for (int i = 0; i < CHUNK_NUM; i++) { + flags += sel[i] * 2**i; + } + flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + + // Perform the lookup against the program + expr delta_pc; + col witness delta_addr; // Auxiliary column + delta_pc = pc' - pc; + delta_addr === (addr - 'addr) * (1 - reset); + lookup_assumes(MEM_ALIGN_ROM_ID, [pc, delta_pc, delta_addr, offset, width, flags]); + + // 2] Assume aligned memory accesses against the Memory component + const expr sel_assume = sel_up_to_down + sel_down_to_up; + + // Offset should be 0 in aligned memory accesses, but this is ensured by the rom + // Width should be 8 in aligned memory accesses, but this is ensured by the rom + + // On assume steps, we reconstruct the value from the registers directly + expr assume_val[RC]; + for (int rc_index = 0; rc_index < RC; rc_index++) { + assume_val[rc_index] = 0; + int base = 1; + for (int _offset = 0; _offset < CHUNKS_BY_RC; _offset++) { + assume_val[rc_index] += reg[_offset + rc_index * CHUNKS_BY_RC] * base; + base *= 256; + } + } + + // 3] Prove unaligned memory accesses against the Main component + col witness sel_prove; + + sel_prove * sel_assume === 0; // Disjoint selectors + + // On prove steps, we reconstruct the value in the correct manner chosen by the selectors + expr prove_val[RC]; + for (int rc_index = 0; rc_index < RC; rc_index++) { + prove_val[rc_index] = 0; + } + for (int _offset = 0; _offset < CHUNK_NUM; _offset++) { + for (int rc_index = 0; rc_index < RC; rc_index++) { + expr _tmp = 0; + int base = 1; + for (int ichunk = 0; ichunk < CHUNKS_BY_RC; ichunk++) { + _tmp += reg[(_offset + rc_index * CHUNKS_BY_RC + ichunk) % CHUNK_NUM] * base; + base *= 256; + } + prove_val[rc_index] += sel[_offset] * _tmp; + } + } + + // We prove and assume with the same permutation check but with disjoint and different sign selectors + col witness value[RC]; // Auxiliary columns + for (int i = 0; i < RC; i++) { + value[i] === sel_prove * prove_val[i] + sel_assume * assume_val[i]; + } + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...value], sel: sel_prove - sel_assume); +} \ No newline at end of file diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil new file mode 100644 index 00000000..3d7735bf --- /dev/null +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -0,0 +1,323 @@ +require "std_lookup.pil" + +const int MEM_ALIGN_ROM_ID = 133; +const int MEM_ALIGN_ROM_SIZE = P2_8; + +airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int DEFAULT_OFFSET = 0, const int DEFAULT_WIDTH = 8, const int disable_fixed = 0) { + if (N < MEM_ALIGN_ROM_SIZE) { + error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); + } + + col witness multiplicity; + + if (disable_fixed) { + col fixed _K = [0...]; + multiplicity * _K === 0; + + println("*** DISABLE_FIXED ***"); + return; + } + + // Define the size of each sub-program: RV, RWV, RVR, RWVWR + const int spsize[4] = [2, 3, 3, 5]; + + // Not all combinations of offset and width are valid for each program: + const int one_word_combinations = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 + const int two_word_combinations = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 + + // table_size = combinations * program_size + const int tsize[4] = [one_word_combinations*spsize[0], one_word_combinations*spsize[1], two_word_combinations*spsize[2], two_word_combinations*spsize[3]]; + const int psize = tsize[0] + tsize[1] + tsize[2] + tsize[3]; + + // Offset is set to DEFAULT_OFFSET and width to DEFAULT_WIDTH in aligned memory accesses. + // Offset and width are set to 0 in padding lines. + // size + col fixed OFFSET = [0, // Padding 1 = 1 | 1 + [[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 41 + [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 101 + [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 134 + [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3], // RWVWR 5*4+10+10+15 = 55 | 189 => N = 2^8 + 0...]; // Padding + + col fixed WIDTH = [0, // Padding + [[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV + [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV + [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR + [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]], // RWVWR + 0...]; // Padding + + // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | + // 0 | 0 | 0 | 1 | 0 | 0 | // for padding + // 1 | 0 | 1 | 1 | X1 | 0 | // (RV) + // 2 | 1 | -1 | 0 | X1 | 0 | + // 3 | 0 | 3 | 1 | X2 | 0 | // (RV) + // 4 | 3 | -3 | 0 | X2 | 0 | + // 5 | 0 | 5 | 1 | X3 | 0 | // (RV) + // 6 | 5 | -5 | 0 | X3 | 0 | + // 7 | 0 | 7 | 1 | ⋮ | ⋮ | // (RV) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 41 | 0 | 41 | 1 | X4 | 0 | // (RWV) + // 42 | 41 | 1 | 0 | X4 | 0 | + // 43 | 42 | -42 | 0 | X4 | 0 | + // 44 | 0 | 44 | 1 | X5 | 0 | // (RWV) + // 45 | 44 | 1 | 0 | X5 | 0 | + // 46 | 45 | -45 | 0 | X5 | 0 | + // 47 | 0 | 47 | 1 | X6 | 0 | // (RWV) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 101 | 0 | 101 | 1 | X7 | 0 | // (RVR) + // 102 |101 | 1 | 0 | X7 | 0 | + // 103 |102 | -102 | 0 | X7+1 | 1 | + // 104 | 0 | 104 | 1 | X8 | 0 | // (RVR) + // 105 |104 | 1 | 0 | X8 | 0 | + // 106 |105 | -105 | 0 | X8+1 | 1 | + // 107 | 0 | 107 | 1 | X9 | 0 | // (RVR) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 134 | 0 | 134 | 1 | X10 | 0 | // (RWVWR) + // 135 |134 | 1 | 0 | X10 | 0 | + // 136 |135 | 1 | 0 | X10 | 0 | + // 137 |136 | 1 | 0 | X10+1 | 1 | + // 138 |137 | -137 | 0 | X10+1 | 0 | + // 139 | 0 | 139 | 1 | X11 | 0 | // (RWVWR) + // 140 |139 | 1 | 0 | X11 | 0 | + // 141 |140 | 1 | 0 | X11 | 0 | + // 142 |141 | 1 | 0 | X11+1 | 1 | + // 143 |142 | -142 | 0 | X11+1 | 0 | + // 144 | 0 | 144 | 1 | X12 | 0 | // (RWVWR) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 188 |187 | -187 | 0 | X13+1 | 0 | + // 189 | 0 | 0 | 1 | 0 | 0 | // for padding + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + + // Note: The overall program contains "holes", meaning that pc can vary + // from program to program by any constant, as long as it is unique for each program. + // For example, the first program has pc=0,1, while the second has pc=0,3. + + col fixed PC; + col fixed DELTA_PC; + col fixed DELTA_ADDR; + col fixed FLAGS; + for (int i = 0; i < N; i++) { + int pc = 0; + int delta_pc = 0; + int delta_addr = 0; + int is_write = 0; + int reset = 0; + int sel[CHUNK_NUM]; + for (int j = 0; j < CHUNK_NUM; j++) { + sel[j] = 0; + } + int sel_up_to_down = 0; + int sel_down_to_up = 0; + + const int prev_line = i == 0 ? 0 : i-1; + const int line = i; + if (line == 0 || line > psize) + { + // pc = 0; + // delta_pc = 0; + // delta_addr = 0; + // is_write = 0; + reset = 1; + // sel = [0:CHUNK_NUM] + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + else if (line < 1+tsize[0]) // RV + { + if (line % 2 == 1) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < 1+tsize[0]+tsize[1]) // RWV + { + if (line % 3 == 2) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2] || j >= OFFSET[i+2] + WIDTH[i+2]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 0) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < 1+tsize[0]+tsize[1]+tsize[2]) // RVR + { + if (line % 3 == 2) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 0) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + delta_addr = 1; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + else if (line < 1+tsize[0]+tsize[1]+tsize[2]+tsize[3]) // RWVWR + { + if (line % 5 == 4) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 5 == 0) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 5 == 1) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } else if (line % 5 == 2) { + pc = prev_line; + delta_pc = 1; + delta_addr = 1; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= (OFFSET[i-2] + WIDTH[i-2]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + PC[i] = pc; + DELTA_PC[i] = delta_pc; + DELTA_ADDR[i] = delta_addr; + int flags = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + flags += sel[j] * 2**j; + } + flags += is_write * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + FLAGS[i] = flags; + } + + // Ensure the program is being followed by the MemAlign + lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); +} \ No newline at end of file diff --git a/state-machines/mem/src/input_data_sm.rs b/state-machines/mem/src/input_data_sm.rs new file mode 100644 index 00000000..220fc33f --- /dev/null +++ b/state-machines/mem/src/input_data_sm.rs @@ -0,0 +1,377 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{ + MemAirValues, MemInput, MemModule, MemPreviousSegment, MEMORY_MAX_DIFF, MEM_BYTES_BITS, +}; +use num_bigint::BigInt; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; +use zisk_core::{INPUT_ADDR, MAX_INPUT_SIZE}; +use zisk_pil::{InputDataTrace, INPUT_DATA_AIR_IDS, ZISK_AIRGROUP_ID}; + +const INPUT_W_ADDR_INIT: u32 = INPUT_ADDR as u32 >> MEM_BYTES_BITS; +const INPUT_W_ADDR_END: u32 = (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 >> MEM_BYTES_BITS; + +#[allow(clippy::assertions_on_constants)] +const _: () = { + assert!( + (MAX_INPUT_SIZE - 1) >> MEM_BYTES_BITS as u64 <= MEMORY_MAX_DIFF, + "INPUT_DATA is too large" + ); + assert!( + INPUT_ADDR + MAX_INPUT_SIZE - 1 <= 0xFFFF_FFFF, + "INPUT_DATA memory exceeds the 32-bit addressable range" + ); +}; + +pub struct InputDataSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + num_rows: usize, + // Count of registered predecessors + registered_predecessors: AtomicU32, +} + +#[allow(unused, unused_variables)] +impl InputDataSM { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, INPUT_DATA_AIR_IDS[0]); + let input_data_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + num_rows: air.num_rows(), + registered_predecessors: AtomicU32::new(0), + }; + let input_data_sm = Arc::new(input_data_sm); + + wcm.register_component( + input_data_sm.clone(), + Some(ZISK_AIRGROUP_ID), + Some(INPUT_DATA_AIR_IDS), + ); + std.register_predecessor(); + + input_data_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } + } + + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs + // memory only need to process these special inputs, but inputs no change. At end of + // inputs proxy add an extra internal input to jump to last address + + let air_id = INPUT_DATA_AIR_IDS[0]; + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, air_id); + let air_rows = air.num_rows(); + + // at least one row to go + let count = inputs.len(); + let count_rem = count % air_rows; + let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; + + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_segments]); + let mut global_idxs = vec![0; num_segments]; + + #[allow(clippy::needless_range_loop)] + for i in 0..num_segments { + // TODO: Review + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1) + { + let trace: InputDataTrace<'_, _> = InputDataTrace::new(air_rows); + let mut buffer = trace.buffer.unwrap(); + prover_buffers.lock().unwrap()[i] = buffer; + global_idxs[i] = global_idx; + } + } + + #[allow(clippy::needless_range_loop)] + for segment_id in 0..num_segments { + let is_last_segment = segment_id == num_segments - 1; + let input_offset = segment_id * air_rows; + let previous_segment = if (segment_id == 0) { + MemPreviousSegment { addr: INPUT_W_ADDR_INIT, step: 0, value: 0 } + } else { + MemPreviousSegment { + addr: inputs[input_offset - 1].addr, + step: inputs[input_offset - 1].step, + value: inputs[input_offset - 1].value, + } + }; + let input_end = + if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; + let mem_ops = &inputs[input_offset..input_end]; + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + &previous_segment, + prover_buffer, + air_rows, + global_idxs[segment_id], + ); + } + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. + /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `ZiskRequiredMemory` inputs + #[allow(clippy::too_many_arguments)] + pub fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + previous_segment: &MemPreviousSegment, + mut prover_buffer: Vec, + air_mem_rows: usize, + global_idx: usize, + ) -> Result<(), Box> { + assert!( + !mem_ops.is_empty() && mem_ops.len() <= air_mem_rows, + "InputDataSM: mem_ops.len()={} out of range {}", + mem_ops.len(), + air_mem_rows + ); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + //println! {"InputDataSM::prove_instance() mem_ops.len={} prover_buffer.len={} + // air.num_rows={}", mem_ops.len(), prover_buffer.len(), air.num_rows()}; + let mut trace = + InputDataTrace::::map_buffer(&mut prover_buffer, air_mem_rows, 0).unwrap(); + + let mut range_check_data: Vec = vec![0; 1 << 16]; + + let mut air_values = MemAirValues { + segment_id: segment_id as u32, + is_first_segment: segment_id == 0, + is_last_segment, + previous_segment_addr: previous_segment.addr, + previous_segment_step: previous_segment.step, + previous_segment_value: [ + previous_segment.value as u32, + (previous_segment.value >> 32) as u32, + ], + ..MemAirValues::default() + }; + + // range of instance + let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + self.std.range_check( + F::from_canonical_u32(previous_segment.addr - INPUT_W_ADDR_INIT + 1), + F::one(), + range_id, + ); + + // Fill the remaining rows + let mut last_addr: u32 = previous_segment.addr; + let mut last_step: u64 = previous_segment.step; + let mut last_value: u64 = previous_segment.value; + + for (i, mem_op) in mem_ops.iter().enumerate() { + trace[i].addr = F::from_canonical_u32(mem_op.addr); + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = F::from_bool(!mem_op.is_internal); + + let value = mem_op.value; + let value_words = self.get_u16_values(value); + for j in 0..4 { + range_check_data[value_words[j] as usize] += 1; + trace[i].value_word[j] = F::from_canonical_u16(value_words[j]); + } + + let addr_changes = last_addr != mem_op.addr; + trace[i].addr_changes = + if addr_changes || (i == 0 && segment_id == 0) { F::one() } else { F::zero() }; + + last_addr = mem_op.addr; + last_step = mem_op.step; + last_value = mem_op.value; + } + + // STEP3. Add dummy rows to the output vector to fill the remaining rows + //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0 + let last_row_idx = mem_ops.len() - 1; + let addr = trace[last_row_idx].addr; + let value = trace[last_row_idx].value_word; + + let padding_size = air_mem_rows - mem_ops.len(); + for i in mem_ops.len()..air_mem_rows { + last_step += 1; + + // TODO CHECK + // trace[i].mem_segment = segment_id_field; + // trace[i].mem_last_segment = is_last_segment_field; + + trace[i].addr = addr; + trace[i].step = F::from_canonical_u64(last_step); + trace[i].sel = F::zero(); + + trace[i].value_word = value; + + trace[i].addr_changes = F::zero(); + } + + air_values.segment_last_addr = last_addr; + air_values.segment_last_step = last_step; + air_values.segment_last_value[0] = last_value as u32; + air_values.segment_last_value[1] = (last_value >> 32) as u32; + + self.std.range_check( + F::from_canonical_u32(INPUT_W_ADDR_END - last_addr + 1), + F::one(), + range_id, + ); + + // range of chunks + let range_id = self.std.get_range(BigInt::from(0), BigInt::from((1 << 16) - 1), None); + for (value, &multiplicity) in range_check_data.iter().enumerate() { + if (multiplicity == 0) { + continue; + } + + self.std.range_check( + F::from_canonical_usize(value), + F::from_canonical_u64(multiplicity), + range_id, + ); + } + for value_chunk in &value { + self.std.range_check(*value_chunk, F::from_canonical_usize(padding_size), range_id); + } + + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + let mut air_instance = AirInstance::new( + self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + INPUT_DATA_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + self.set_airvalues("InputData", &mut air_instance, &air_values); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + + Ok(()) + } + + fn get_u16_values(&self, value: u64) -> [u16; 4] { + [value as u16, (value >> 16) as u16, (value >> 32) as u16, (value >> 48) as u16] + } + fn set_airvalues( + &self, + prefix: &str, + air_instance: &mut AirInstance, + air_values: &MemAirValues, + ) { + air_instance.set_airvalue( + format!("{}.segment_id", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_id), + ); + air_instance.set_airvalue( + format!("{}.is_first_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_first_segment), + ); + air_instance.set_airvalue( + format!("{}.is_last_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_last_segment), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.previous_segment_addr), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.previous_segment_step), + ); + air_instance.set_airvalue( + format!("{}.segment_last_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_last_addr), + ); + air_instance.set_airvalue( + format!("{}.segment_last_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.segment_last_step), + ); + let count = air_values.previous_segment_value.len(); + for i in 0..count { + air_instance.set_airvalue( + format!("{}.previous_segment_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.previous_segment_value[i]), + ); + air_instance.set_airvalue( + format!("{}.segment_last_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.segment_last_value[i]), + ); + } + } +} + +impl MemModule for InputDataSM { + fn send_inputs(&self, mem_op: &[MemInput]) { + self.prove(mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(INPUT_ADDR as u32, (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32)] + } + fn get_flush_input_size(&self) -> u32 { + self.num_rows as u32 + } +} + +impl WitnessComponent for InputDataSM {} diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 67bf225c..3c42869b 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,9 +1,23 @@ -mod mem; -mod mem_aligned; -mod mem_traces; -mod mem_unaligned; +mod input_data_sm; +mod mem_align_rom_sm; +mod mem_align_sm; +mod mem_constants; +mod mem_helpers; +mod mem_module; +mod mem_proxy; +mod mem_proxy_engine; +mod mem_sm; +mod mem_unmapped; +mod rom_data; -pub use mem::*; -pub use mem_aligned::*; -pub use mem_traces::*; -pub use mem_unaligned::*; +pub use input_data_sm::*; +pub use mem_align_rom_sm::*; +pub use mem_align_sm::*; +pub use mem_constants::*; +pub use mem_helpers::*; +pub use mem_module::*; +pub use mem_proxy::*; +pub use mem_proxy_engine::*; +pub use mem_sm::*; +pub use mem_unmapped::*; +pub use rom_data::*; diff --git a/state-machines/mem/src/mem.rs b/state-machines/mem/src/mem.rs deleted file mode 100644 index 065b1841..00000000 --- a/state-machines/mem/src/mem.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use crate::{MemAlignedSM, MemUnalignedSM}; -use p3_field::Field; -use rayon::Scope; -use sm_common::{MemOp, MemUnalignedOp, OpResult, Provable}; -use zisk_core::ZiskRequiredMemory; - -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; - -#[allow(dead_code)] -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -#[allow(dead_code)] -pub struct MemSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs_aligned: Mutex>, - inputs_unaligned: Mutex>, - - // Secondary State machines - mem_aligned_sm: Arc, - mem_unaligned_sm: Arc, -} - -impl MemSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = MemAlignedSM::new(wcm.clone()); - let mem_unaligned_sm = MemUnalignedSM::new(wcm.clone()); - - let mem_sm = Self { - registered_predecessors: AtomicU32::new(0), - inputs_aligned: Mutex::new(Vec::new()), - inputs_unaligned: Mutex::new(Vec::new()), - mem_aligned_sm: mem_aligned_sm.clone(), - mem_unaligned_sm: mem_unaligned_sm.clone(), - }; - let mem_sm = Arc::new(mem_sm); - - wcm.register_component(mem_sm.clone(), None, None); - - // For all the secondary state machines, register the main state machine as a predecessor - mem_sm.mem_aligned_sm.register_predecessor(); - mem_sm.mem_unaligned_sm.register_predecessor(); - - mem_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - - self.mem_aligned_sm.unregister_predecessor::(scope); - self.mem_unaligned_sm.unregister_predecessor::(scope); - } - } -} - -impl WitnessComponent for MemSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemSM { - fn calculate( - &self, - _operation: ZiskRequiredMemory, - ) -> Result> { - unimplemented!() - } - - fn prove(&self, _operations: &[ZiskRequiredMemory], _drain: bool, _scope: &Scope) { - // TODO! - } - - fn calculate_prove( - &self, - _operation: ZiskRequiredMemory, - _drain: bool, - _scope: &Scope, - ) -> Result> { - unimplemented!() - } -} diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs new file mode 100644 index 00000000..486c05dd --- /dev/null +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -0,0 +1,214 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, + }, +}; + +use log::info; +use p3_field::PrimeField; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use zisk_pil::{MemAlignRomRow, MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; + +#[derive(Debug, Clone, Copy)] +pub enum MemOp { + OneRead, + OneWrite, + TwoReads, + TwoWrites, +} + +const OP_SIZES: [u64; 4] = [2, 3, 3, 5]; +const ONE_WORD_COMBINATIONS: u64 = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 +const TWO_WORD_COMBINATIONS: u64 = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 + +pub struct MemAlignRomSM { + // Witness computation manager + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Rom data + num_rows: usize, + multiplicity: Mutex>, // row_num -> multiplicity +} + +#[derive(Debug)] +pub enum ExtensionTableSMErr { + InvalidOpcode, +} + +impl MemAlignRomSM { + const MY_NAME: &'static str = "MemAlignRom"; + + pub fn new(wcm: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let num_rows = air.num_rows(); + + let mem_align_rom = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + num_rows, + multiplicity: Mutex::new(HashMap::with_capacity(num_rows)), + }; + let mem_align_rom = Arc::new(mem_align_rom); + wcm.register_component( + mem_align_rom.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_ROM_AIR_IDS), + ); + + mem_align_rom + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.create_air_instance(); + } + } + + pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, width: usize) -> u64 { + // Get the table offset + let (table_offset, one_word) = match opcode { + MemOp::OneRead => (1, true), + + MemOp::OneWrite => (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true), + + MemOp::TwoReads => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], + false, + ), + + MemOp::TwoWrites => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2], + false, + ), + }; + + // Get the first row index + let first_row_idx = Self::get_first_row_idx(opcode, offset, width, table_offset, one_word); + + // Based on the program size, return the row indices + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; + for i in 0..op_size { + let row_idx = first_row_idx + i; + // Check whether the row index is within the bounds + debug_assert!(row_idx < self.num_rows as u64); + // Update the multiplicity + self.update_multiplicity_by_row_idx(row_idx, 1); + } + + first_row_idx + } + + fn get_first_row_idx( + opcode: MemOp, + offset: usize, + width: usize, + table_offset: u64, + one_word: bool, + ) -> u64 { + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; + + // Go to the actual operation + let mut first_row_idx = table_offset; + + // Go to the actual offset + let first_valid_offset = if one_word { 0 } else { 1 }; + for i in first_valid_offset..offset { + let possible_widths = Self::calculate_possible_widths(one_word, i); + first_row_idx += op_size * possible_widths.len() as u64; + } + + // Go to the right width + let width_idx = Self::calculate_possible_widths(one_word, offset) + .iter() + .position(|&w| w == width) + .unwrap_or_else(|| panic!("Invalid width offset:{} width:{}", offset, width)); + first_row_idx += op_size * width_idx as u64; + + first_row_idx + } + + fn calculate_possible_widths(one_word: bool, offset: usize) -> Vec { + // Calculate the ROM rows based on the requested opcode, offset, and width + match one_word { + true => match offset { + x if x <= 4 => vec![1, 2, 4], + x if x <= 6 => vec![1, 2], + 7 => vec![1], + _ => panic!("Invalid offset={}", offset), + }, + false => match offset { + 0 => panic!("Invalid offset={}", offset), + x if x <= 4 => vec![8], + x if x <= 6 => vec![4, 8], + 7 => vec![2, 4, 8], + _ => panic!("Invalid offset={}", offset), + }, + } + } + + pub fn update_padding_row(&self, padding_len: u64) { + // Update entry at the padding row (pos = 0) with the given padding length + self.update_multiplicity_by_row_idx(0, padding_len); + } + + pub fn update_multiplicity_by_row_idx(&self, row_idx: u64, mul: u64) { + let mut multiplicity = self.multiplicity.lock().unwrap(); + *multiplicity.entry(row_idx).or_insert(0) += mul; + } + + pub fn create_air_instance(&self) { + // Get the contexts + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + // Get the Mem Align ROM AIR + let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let air_mem_align_rom_rows = air_mem_align_rom.num_rows(); + + let mut trace_buffer: MemAlignRomTrace<'_, _> = + MemAlignRomTrace::new(air_mem_align_rom_rows); + + // Initialize the trace buffer to zero + for i in 0..air_mem_align_rom_rows { + trace_buffer[i] = MemAlignRomRow { multiplicity: F::zero() }; + } + + // Fill the trace buffer with the multiplicity values + if let Ok(multiplicity) = self.multiplicity.lock() { + for (row_idx, multiplicity) in multiplicity.iter() { + trace_buffer[*row_idx as usize] = + MemAlignRomRow { multiplicity: F::from_canonical_u64(*multiplicity) }; + } + } + + info!("{}: ··· Creating Mem Align Rom instance", Self::MY_NAME,); + + let air_instance = AirInstance::new( + sctx, + ZISK_AIRGROUP_ID, + MEM_ALIGN_ROM_AIR_IDS[0], + None, + trace_buffer.buffer.unwrap(), + ); + pctx.air_instance_repo.add_air_instance(air_instance, None); + } +} + +impl WitnessComponent for MemAlignRomSM {} diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs new file mode 100644 index 00000000..7433c007 --- /dev/null +++ b/state-machines/mem/src/mem_align_sm.rs @@ -0,0 +1,1015 @@ +use core::panic; +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use log::info; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; + +use crate::{MemAlignInput, MemAlignRomSM, MemOp}; + +const RC: usize = 2; +const CHUNK_NUM: usize = 8; +const CHUNKS_BY_RC: usize = CHUNK_NUM / RC; +const CHUNK_BITS: usize = 8; +const RC_BITS: u64 = (CHUNKS_BY_RC * CHUNK_BITS) as u64; +const RC_MASK: u64 = (1 << RC_BITS) - 1; +const OFFSET_MASK: u32 = 0x07; +const OFFSET_BITS: u32 = 3; +const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; + +const fn generate_allowed_offsets() -> [u8; CHUNK_NUM] { + let mut offsets = [0; CHUNK_NUM]; + let mut i = 0; + while i < CHUNK_NUM { + offsets[i] = i as u8; + i += 1; + } + offsets +} + +const ALLOWED_OFFSETS: [u8; CHUNK_NUM] = generate_allowed_offsets(); +const ALLOWED_WIDTHS: [u8; 4] = [1, 2, 4, 8]; +const DEFAULT_OFFSET: u64 = 0; +const DEFAULT_WIDTH: u64 = 8; + +pub struct MemAlignResponse { + pub more_addr: bool, + pub step: u64, + pub value: Option, +} +pub struct MemAlignSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Computed row information + rows: Mutex>>, + #[cfg(feature = "debug_mem_align")] + num_computed_rows: Mutex, + + // Secondary State machines + mem_align_rom_sm: Arc>, +} + +macro_rules! debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_align")] + { + info!(concat!("MemAlign: ",$prefix), $($arg)*); + } + }; +} + +impl MemAlignSM { + const MY_NAME: &'static str = "MemAlign"; + + pub fn new( + wcm: Arc>, + std: Arc>, + mem_align_rom_sm: Arc>, + ) -> Arc { + let mem_align_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + registered_predecessors: AtomicU32::new(0), + rows: Mutex::new(Vec::new()), + #[cfg(feature = "debug_mem_align")] + num_computed_rows: Mutex::new(0), + mem_align_rom_sm, + }; + let mem_align_sm = Arc::new(mem_align_sm); + + wcm.register_component( + mem_align_sm.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_AIR_IDS), + ); + + // Register the predecessors + std.register_predecessor(); + mem_align_sm.mem_align_rom_sm.register_predecessor(); + + mem_align_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + + // If there are remaining rows, generate the last instance + if let Ok(mut rows) = self.rows.lock() { + // Get the Mem Align AIR + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + let rows_len = rows.len(); + debug_assert!(rows_len <= air_mem_align.num_rows()); + + let drained_rows = rows.drain(..rows_len).collect::>(); + + self.fill_new_air_instance(&drained_rows); + } + + self.mem_align_rom_sm.unregister_predecessor(); + self.std.unregister_predecessor(pctx, None); + } + } + + #[inline(always)] + pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { + let addr = input.addr; + let width = input.width; + + // Compute the width + debug_assert!( + ALLOWED_WIDTHS.contains(&width), + "Width={} is not allowed. Allowed widths are {:?}", + width, + ALLOWED_WIDTHS + ); + let width = width as usize; + + // Compute the offset + let offset = (addr & OFFSET_MASK) as u8; + debug_assert!( + ALLOWED_OFFSETS.contains(&offset), + "Offset={} is not allowed. Allowed offsets are {:?}", + offset, + ALLOWED_OFFSETS + ); + let offset = offset as usize; + + #[cfg(feature = "debug_mem_align")] + let num_rows = self.num_computed_rows.lock().unwrap(); + match (input.is_write, offset + width > CHUNK_NUM) { + (false, false) => { + /* RV with offset=2, width=4 + +----+----+====+====+====+====+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+====+====+====+====+----+----+ + ⇓ + +----+----+====+====+====+====+----+----+ + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | + +----+----+====+====+====+====+----+----+ + */ + debug_assert!(phase == 0); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Get the aligned address + let addr_read = addr >> OFFSET_BITS; + + // Get the aligned value + let value_read = input.mem_values[phase]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneRead, offset, width); + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); + if i >= offset && i < offset + width { + read_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_read = value_read; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nOne Word Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 1], + input, + phase, + value_read.to_le_bytes(), + value.to_le_bytes(), + [ + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[read_row, value_row]); + + MemAlignResponse { more_addr: false, step, value: None } + } + (true, false) => { + /* RWV with offset=3, width=4 + +----+----+----+====+====+====+====+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+====+====+====+====+----+ + ⇓ + +----+----+----+====+====+====+====+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+====+====+====+====+----+ + ⇓ + +----+----+----+====+====+====+====+----+ + | V5 | V6 | V7 | V0 | V1 | V2 | V3 | V4 | + +----+----+----+====+====+====+====+----+ + */ + debug_assert!(phase == 0); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Get the aligned address + let addr_read = addr >> OFFSET_BITS; + + // Get the aligned value + let value_read = input.mem_values[phase]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneWrite, offset, width); + + // Compute the write value + let value_write = { + // with:1 offset:4 + let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_read & !mask) | value_to_write + }; + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); + if i < offset || i >= offset + width { + read_row.sel[i] = F::from_bool(true); + } + + write_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_write, i, 0)); + if i >= offset && i < offset + width { + write_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = { + if i >= offset && i < offset + width { + write_row.reg[i] + } else { + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)) + } + }; + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_read = value_read; + let mut _value_write = value_write; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + write_row.value[i] = F::from_canonical_u64(_value_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value_write >>= RC_BITS; + _value >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nOne Word Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value Write: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Write: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 2], + input, + phase, + value_read.to_le_bytes(), + value_write.to_le_bytes(), + value.to_le_bytes(), + [ + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + ], + [ + write_row.sel[0], write_row.sel[1], write_row.sel[2], write_row.sel[3], + write_row.sel[4], write_row.sel[5], write_row.sel[6], write_row.sel[7], + write_row.wr, write_row.reset, write_row.sel_up_to_down, write_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[read_row, write_row, value_row]); + + MemAlignResponse { more_addr: false, step, value: Some(value_write) } + } + (false, true) => { + /* RVR with offset=5, width=8 + +----+----+----+----+----+====+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+====+====+====+ + ⇓ + +====+====+====+====+====+====+====+====+ + | V3 | V4 | V5 | V6 | V7 | V0 | V1 | V2 | + +====+====+====+====+====+====+====+====+ + ⇓ + +====+====+====+====+====+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+====+====+====+----+----+----+ + */ + debug_assert!(phase == 0 || phase == 1); + + match phase { + // If phase == 0, do nothing, just ask for more + 0 => MemAlignResponse { more_addr: true, step: input.step, value: None }, + + // Otherwise, do the RVR + 1 => { + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the remaining bytes + let rem_bytes = (offset + width) % CHUNK_NUM; + + // Get the aligned address + let addr_first_read = addr >> OFFSET_BITS; + let addr_second_read = addr_first_read + 1; + + // Get the aligned value + let value_first_read = input.mem_values[0]; + let value_second_read = input.mem_values[1]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::TwoReads, offset, width); + + let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i >= offset { + first_read_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i < rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_first_read = value_first_read; + let mut _value = value; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nTwo Words Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Flags First Read: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 2], + input, + phase, + value_first_read.to_le_bytes(), + value.to_le_bytes(), + value_second_read.to_le_bytes(), + [ + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], + [ + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[first_read_row, value_row, second_read_row]); + + MemAlignResponse { more_addr: false, step, value: None } + } + _ => panic!("Invalid phase={}", phase), + } + } + (true, true) => { + /* RWVWR with offset=6, width=4 + +----+----+----+----+----+----+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+----+====+====+ + ⇓ + +----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+====+====+ + | V2 | V3 | V4 | V5 | V6 | V7 | V0 | V1 | + +====+====+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +====+====+----+----+----+----+----+----+ + ⇓ + +====+====+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+----+----+----+----+----+----+ + */ + debug_assert!(phase == 0 || phase == 1); + + match phase { + // If phase == 0, compute the resulting write value and ask for more + 0 => { + // Unaligned memory op information thrown into the bus + let value = input.value; + let step = input.step; + + // Get the aligned value + let value_first_read = input.mem_values[0]; + + // Compute the write value + let value_first_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + MemAlignResponse { more_addr: true, step, value: Some(value_first_write) } + } + // Otherwise, do the RWVRW + 1 => { + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the shift + let rem_bytes = (offset + width) % CHUNK_NUM; + + // Get the aligned address + let addr_first_read_write = addr >> OFFSET_BITS; + let addr_second_read_write = addr_first_read_write + 1; + + // Get the first aligned value + let value_first_read = input.mem_values[0]; + + // Recompute the first write value + let value_first_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + // Get the second aligned value + let value_second_read = input.mem_values[1]; + + // Compute the second write value + let value_second_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let mask: u64 = (1 << (rem_bytes * CHUNK_BITS)) - 1; + + // Get the first width bytes of the unaligned value + let value_to_write = (value >> (width_norm * CHUNK_BITS)) & mask; + + // Write zeroes to value_read from 0 to offset + width + // and add the value to write to the value read + (value_second_read & !mask) | value_to_write + }; + + // Get the next pc + let next_pc = self.mem_align_rom_sm.calculate_next_pc( + MemOp::TwoWrites, + offset, + width, + ); + + // RWVWR + let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut first_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_second_read_write), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 2), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 3), + reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i < offset { + first_read_row.sel[i] = F::from_bool(true); + } + + first_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_write, i, 0)); + if i >= offset { + first_write_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = { + if i < rem_bytes { + second_write_row.reg[i] + } else if i >= offset { + first_write_row.reg[i] + } else { + F::from_canonical_u64(Self::get_byte( + value, + i, + CHUNK_NUM - offset, + )) + } + }; + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + + second_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); + if i < rem_bytes { + second_write_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i >= rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_first_read = value_first_read; + let mut _value_first_write = value_first_write; + let mut _value = value; + let mut _value_second_write = value_second_write; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + first_write_row.value[i] = + F::from_canonical_u64(_value_first_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_write_row.value[i] = + F::from_canonical_u64(_value_second_write & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value_first_write >>= RC_BITS; + _value >>= RC_BITS; + _value_second_write >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nTwo Words Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value First Write: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Value Second Write: {:?}\n\ + Flags First Read: {:?}\n\ + Flags First Write: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Write: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 4], + input, + phase, + value_first_read.to_le_bytes(), + value_first_write.to_le_bytes(), + value.to_le_bytes(), + value_second_write.to_le_bytes(), + value_second_read.to_le_bytes(), + [ + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], + [ + first_write_row.sel[0], first_write_row.sel[1], first_write_row.sel[2], first_write_row.sel[3], + first_write_row.sel[4], first_write_row.sel[5], first_write_row.sel[6], first_write_row.sel[7], + first_write_row.wr, first_write_row.reset, first_write_row.sel_up_to_down, first_write_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], + [ + second_write_row.sel[0], second_write_row.sel[1], second_write_row.sel[2], second_write_row.sel[3], + second_write_row.sel[4], second_write_row.sel[5], second_write_row.sel[6], second_write_row.sel[7], + second_write_row.wr, second_write_row.reset, second_write_row.sel_up_to_down, second_write_row.sel_down_to_up + ], + [ + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[ + first_read_row, + first_write_row, + value_row, + second_write_row, + second_read_row, + ]); + + MemAlignResponse { more_addr: false, step, value: Some(value_second_write) } + } + _ => panic!("Invalid phase={}", phase), + } + } + } + } + + fn get_byte(value: u64, index: usize, offset: usize) -> u64 { + let chunk = (offset + index) % CHUNK_NUM; + (value >> (chunk * CHUNK_BITS)) & CHUNK_BITS_MASK + } + + pub fn prove(&self, computed_rows: &[MemAlignRow]) { + if let Ok(mut rows) = self.rows.lock() { + rows.extend_from_slice(computed_rows); + + #[cfg(feature = "debug_mem_align")] + { + let mut num_rows = self.num_computed_rows.lock().unwrap(); + *num_rows += computed_rows.len(); + drop(num_rows); + } + + let pctx = self.wcm.get_pctx(); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + while rows.len() >= air_mem_align.num_rows() { + let num_drained = std::cmp::min(air_mem_align.num_rows(), rows.len()); + let drained_rows = rows.drain(..num_drained).collect::>(); + + self.fill_new_air_instance(&drained_rows); + } + } + } + + fn fill_new_air_instance(&self, rows: &[MemAlignRow]) { + // Get the proof context + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + + // Get the Mem Align AIR + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + let air_mem_align_rows = air_mem_align.num_rows(); + let rows_len = rows.len(); + + // You cannot feed to the AIR more rows than it has + debug_assert!(rows_len <= air_mem_align_rows); + + // Get the execution and setup context + let sctx = wcm.get_sctx(); + + let mut trace_buffer: MemAlignTrace<'_, _> = MemAlignTrace::new(air_mem_align_rows); + + let mut reg_range_check: Vec = vec![0; 1 << CHUNK_BITS]; + // Add the input rows to the trace + for (i, &row) in rows.iter().enumerate() { + // Store the entire row + trace_buffer[i] = row; + // Store the value of all reg columns so that they can be range checked + for j in 0..CHUNK_NUM { + let element = + row.reg[j].as_canonical_biguint().to_usize().expect("Cannot convert to usize"); + reg_range_check[element] += 1; + } + } + + // Pad the remaining rows with trivially satisfying rows + let padding_row = MemAlignRow:: { reset: F::from_bool(true), ..Default::default() }; + let padding_size = air_mem_align_rows - rows_len; + + // Store the padding rows + for i in rows_len..air_mem_align_rows { + trace_buffer[i] = padding_row; + } + + // Store the value of all padding reg columns so that they can be range checked + for _ in 0..CHUNK_NUM { + reg_range_check[0] += padding_size as u64; + } + + // Perform the range checks + let std = self.std.clone(); + let range_id = std.get_range(BigInt::from(0), BigInt::from(CHUNK_BITS_MASK), None); + for (value, &multiplicity) in reg_range_check.iter().enumerate() { + std.range_check( + F::from_canonical_usize(value), + F::from_canonical_u64(multiplicity), + range_id, + ); + } + + // Compute the program multiplicity + let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + mem_align_rom_sm.update_padding_row(padding_size as u64); + + info!( + "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + rows_len, + air_mem_align_rows, + rows_len as f64 / air_mem_align_rows as f64 * 100.0 + ); + + // Add a new Mem Align instance + let air_instance = AirInstance::new( + sctx, + ZISK_AIRGROUP_ID, + MEM_ALIGN_AIR_IDS[0], + None, + trace_buffer.buffer.unwrap(), + ); + pctx.air_instance_repo.add_air_instance(air_instance, None); + } +} + +impl WitnessComponent for MemAlignSM {} diff --git a/state-machines/mem/src/mem_aligned.rs b/state-machines/mem/src/mem_aligned.rs deleted file mode 100644 index 1a126e3c..00000000 --- a/state-machines/mem/src/mem_aligned.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_ALIGN_AIR_IDS}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct MemAlignedSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -#[allow(unused, unused_variables)] -impl MemAlignedSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); - - wcm.register_component( - mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_ALIGN_AIR_IDS), - ); - - mem_aligned_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - fn read( - &self, - _addr: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } - - fn write( - &self, - _addr: u64, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } -} - -impl WitnessComponent for MemAlignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemAlignedSM { - fn calculate(&self, operation: MemOp) -> Result> { - match operation { - MemOp::Read(addr) => self.read(addr), - MemOp::Write(addr, val) => self.write(addr, val), - } - } - - fn prove(&self, operations: &[MemOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: MemOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 4e177ee3..9165edd1 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -1,12 +1,17 @@ -pub const MEM_ADDR_MASK: u64 = 0xFFFF_FFFF_FFFF_FFF8; -pub const MEM_BYTES: u64 = 8; +pub const MEM_ADDR_MASK: u32 = 0xFFFF_FFF8; +pub const MEM_BYTES_BITS: u32 = 3; +pub const MEM_BYTES: u32 = 1 << MEM_BYTES_BITS; +pub const MEM_STEP_BASE: u64 = 1; pub const MAX_MEM_STEP_OFFSET: u64 = 2; -pub const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; +pub const MAX_MEM_OPS_BY_STEP_OFFSET: u64 = 2; +pub const MAX_MEM_OPS_BY_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * MAX_MEM_OPS_BY_STEP_OFFSET; -pub const MEM_STEP_BITS: u64 = 34; // with step_slot = 8 => 2GB steps ( -pub const MEM_STEP_MASK: u64 = (1 << MEM_STEP_BITS) - 1; // 256 MB -pub const MEM_ADDR_BITS: u64 = 64 - MEM_STEP_BITS; +pub const MAX_MAIN_STEP: u64 = 0x1FFF_FFFF_FFFF_FFFF; +pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * MAX_MAIN_STEP + + MAX_MEM_OPS_BY_STEP_OFFSET * MAX_MEM_STEP_OFFSET; -pub const MAX_MEM_STEP: u64 = (1 << MEM_STEP_BITS) - 1; -pub const MAX_MEM_ADDR: u64 = (1 << MEM_ADDR_BITS) - 1; +pub const MAX_MEM_ADDR: u64 = 0xFFFF_FFFF; + +pub const MEMORY_MAX_DIFF: u64 = 1 << 24; diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs index ac4ca198..8e70b537 100644 --- a/state-machines/mem/src/mem_helpers.rs +++ b/state-machines/mem/src/mem_helpers.rs @@ -1,7 +1,10 @@ -use crate::MemAlignResponse; +use crate::{ + MemAlignResponse, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_OPS_BY_STEP_OFFSET, MEM_STEP_BASE, +}; use std::fmt; use zisk_core::ZiskRequiredMemory; +#[allow(dead_code)] fn format_u64_hex(value: u64) -> String { let hex_str = format!("{:016x}", value); hex_str @@ -12,54 +15,73 @@ fn format_u64_hex(value: u64) -> String { .join("_") } +#[derive(Debug, Clone)] +pub struct MemAlignInput { + pub addr: u32, + pub is_write: bool, + pub width: u8, + pub step: u64, + pub value: u64, + pub mem_values: [u64; 2], +} + +#[derive(Debug, Clone)] +pub struct MemInput { + pub addr: u32, // address in word native format means byte_address / MEM_BYTES + pub is_write: bool, // it's a write operation + pub is_internal: bool, // internal operation, don't send this operation to bus + pub step: u64, // mem_step = f(main_step, main_step_offset) + pub value: u64, // value to read or write +} + +impl MemAlignInput { + pub fn new( + addr: u32, + is_write: bool, + width: u8, + step: u64, + value: u64, + mem_values: [u64; 2], + ) -> Self { + MemAlignInput { addr, is_write, width, step, value, mem_values } + } + pub fn from(mem_external_op: &ZiskRequiredMemory, mem_values: &[u64; 2]) -> Self { + match mem_external_op { + ZiskRequiredMemory::Basic { step, value, address, is_write, width, step_offset } => { + MemAlignInput { + addr: *address, + is_write: *is_write, + step: MemHelpers::main_step_to_address_step(*step, *step_offset), + width: *width, + value: *value, + mem_values: [mem_values[0], mem_values[1]], + } + } + ZiskRequiredMemory::Extended { values: _, address: _ } => { + panic!("MemAlignInput::from() called with extended instance") + } + } + } +} + +pub struct MemHelpers {} + +impl MemHelpers { + pub fn main_step_to_address_step(step: u64, step_offset: u8) -> u64 { + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64 + } +} + impl fmt::Debug for MemAlignResponse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "more:{0} step:{1} value:{2:016X}({2:})", - self.more_address, + self.more_addr, self.step, self.value.unwrap_or(0) ) } } - -pub fn mem_align_call( - mem_op: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: u8, -) -> MemAlignResponse { - // DEBUG: only for testing - let offset = (mem_op.address & 0x7) * 8; - let width = (mem_op.width as u64) * 8; - let double_address = (offset + width as u32) > 64; - let mem_value = mem_values[phase as usize]; - let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); - if mem_op.is_write { - if phase == 0 { - MemAlignResponse { - more_address: double_address, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) - | ((mem_op.value & mask) << offset), - ), - } - } else { - MemAlignResponse { - more_address: false, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) - | ((mem_op.value & mask) >> (128 - (offset + width as u32))), - ), - } - } - } else { - MemAlignResponse { - more_address: double_address && phase == 0, - step: mem_op.step + 1, - value: None, - } - } -} diff --git a/state-machines/mem/src/mem_module.rs b/state-machines/mem/src/mem_module.rs new file mode 100644 index 00000000..59308fd3 --- /dev/null +++ b/state-machines/mem/src/mem_module.rs @@ -0,0 +1,31 @@ +use crate::{MemHelpers, MemInput, MEM_BYTES}; +use zisk_core::ZiskRequiredMemory; + +impl MemInput { + pub fn new(addr: u32, is_write: bool, step: u64, value: u64, is_internal: bool) -> Self { + MemInput { addr, is_write, step, value, is_internal } + } + pub fn from(mem_op: &ZiskRequiredMemory) -> Self { + match mem_op { + ZiskRequiredMemory::Basic { step, value, address, is_write, width, step_offset } => { + debug_assert_eq!(*width, MEM_BYTES as u8); + MemInput { + addr: address >> 3, + is_write: *is_write, + is_internal: false, + step: MemHelpers::main_step_to_address_step(*step, *step_offset), + value: *value, + } + } + ZiskRequiredMemory::Extended { values: _, address: _ } => { + panic!("MemInput::from() called with an extended instance"); + } + } + } +} + +pub trait MemModule: Send + Sync { + fn send_inputs(&self, mem_op: &[MemInput]); + fn get_addr_ranges(&self) -> Vec<(u32, u32)>; + fn get_flush_input_size(&self) -> u32; +} diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs new file mode 100644 index 00000000..a5bcf320 --- /dev/null +++ b/state-machines/mem/src/mem_proxy.rs @@ -0,0 +1,79 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, +}; + +use crate::{InputDataSM, MemAlignRomSM, MemAlignSM, MemProxyEngine, MemSM, RomDataSM}; +use p3_field::PrimeField; +use pil_std_lib::Std; +use zisk_core::ZiskRequiredMemory; + +use proofman::{WitnessComponent, WitnessManager}; + +pub struct MemProxy { + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Secondary State machines + mem_sm: Arc>, + mem_align_sm: Arc>, + mem_align_rom_sm: Arc>, + input_data_sm: Arc>, + rom_data_sm: Arc>, +} + +impl MemProxy { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let mem_align_rom_sm = MemAlignRomSM::new(wcm.clone()); + let mem_align_sm = MemAlignSM::new(wcm.clone(), std.clone(), mem_align_rom_sm.clone()); + let mem_sm = MemSM::new(wcm.clone(), std.clone()); + let input_data_sm = InputDataSM::new(wcm.clone(), std.clone()); + let rom_data_sm = RomDataSM::new(wcm.clone(), std.clone()); + + let mem_proxy = Self { + registered_predecessors: AtomicU32::new(0), + mem_align_sm, + mem_align_rom_sm, + mem_sm, + input_data_sm, + rom_data_sm, + }; + let mem_proxy = Arc::new(mem_proxy); + + wcm.register_component(mem_proxy.clone(), None, None); + + // For all the secondary state machines, register the main state machine as a predecessor + mem_proxy.mem_align_rom_sm.register_predecessor(); + mem_proxy.mem_align_sm.register_predecessor(); + mem_proxy.mem_sm.register_predecessor(); + mem_proxy.input_data_sm.register_predecessor(); + mem_proxy.rom_data_sm.register_predecessor(); + mem_proxy + } + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.mem_align_rom_sm.unregister_predecessor(); + self.mem_align_sm.unregister_predecessor(); + self.mem_sm.unregister_predecessor(); + self.input_data_sm.unregister_predecessor(); + self.rom_data_sm.unregister_predecessor(); + } + } + + pub fn prove( + &self, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + let mut engine = MemProxyEngine::::new(self.mem_align_sm.clone()); + engine.add_module("mem", self.mem_sm.clone()); + engine.add_module("input_data", self.input_data_sm.clone()); + engine.add_module("row_data", self.rom_data_sm.clone()); + engine.prove(mem_operations) + } +} + +impl WitnessComponent for MemProxy {} diff --git a/state-machines/mem/src/mem_proxy_engine.rs b/state-machines/mem/src/mem_proxy_engine.rs new file mode 100644 index 00000000..3ca6bf5d --- /dev/null +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -0,0 +1,628 @@ +//! The `MemProxyEngine` module is designed to facilitate dividing the proxy logic into smaller, +//! more manageable pieces of code. +//! +//! The engine is created through MemProxy on a static call, which creates the `MemProxyEngine`. +//! `MemProxyEngine` has state, and this state allows the implementation of smaller, focused +//! methods, making the codebase easier to maintain and extend. +//! +//! +//! ## Creation and Setup of the `MemProxyEngine` +//! +//! When creating the `MemProxyEngine`, a state machine is provided to handle alignment of memory +//! accesses. This state machine is responsible for demostrate unaligned accesses based on aligned +//! ones. +//! +//! Once the `MemProxyEngine` is created, all memory modules are registered. These modules must +//! implement the `MemModule` trait, which serves three purposes: +//! +//! 1. To define the range of addresses (regions) they are responsible for handling. +//! 2. To specify the frequency (number of inputs) at which they expect to receive inputs. +//! 3. To define the "callback" used to send inputs to the module +//! +//! +//! ## Inputs from `MemProxyEngine` +//! +//! The inputs to the `MemProxyEngine` are represented as an enumeration to optimize memory usage +//! and performance. This design ensures efficient handling of both common and rare cases, +//! balancing memory allocation and computational efficiency. +//! +//! The enumeration has two variants: +//! 1. `Basic`: The primary input type, used for the majority of memory accesses. This variant is +//! highly optimized to minimize overhead and ensure efficient processing in typical scenarios. +//! 2. `Extended`: A specialized input type used exclusively for handling unaligned memory +//! accesses. This variant is appended to the vector immediately after the corresponding `Basic` +//! instance that generates it. The `Extended` input contains the aligned memory values required +//! to process the unaligned access (in word case two values) +//! +//! By adopting this design, the `MemProxyEngine` avoids penalizing the commonly used `Basic` type +//! due to the less frequent unaligned cases that requires addicional `Extended` type. This +//! separation ensures that unaligned access handling introduces minimal overhead to the overall +//! system, while still providing the flexibility to unaligned access. +//! +//! +//! ## Logic of the `MemProxyEngine` +//! +//! Step 1. Sort the aligned memory accesses +//! original vector is sorted by step, sort_by_key is stable, no reordering of elements with +//! the same key. +//! +//! Step 2. Add a final mark mem_op to force flush of open_mem_align_ops, because always the +//! last operation is mem_op. +//! +//! Step 3. Composing information for memory operation (access). In this step, all necessary +//! information is gathered and composed to perform a memory operation. The process involves +//! reading the next input from the input vector, which defines the nature of the operation. +//! +//! - For standard (aligned) operations, only the `Basic` input is required, and the operation +//! proceeds directly. +//! - For unaligned operations, the `Extended` input is also read. This additional input provides +//! the extra values required to handle the unaligned operation. +//! +//! Step 4. Process each memory operation ordered by address and step. When a non-aligned +//! memory access there are two possible situations: +//! +//! 1. The operation applies only applies to one memory address (read or read+write). In this case +//! mem_align helper return the aligned operation for this address, and loop continues. +//! +//! 2. The operation applies to two consecutive memory addresses, mem_align helper returns the +//! aligned operation involved for the current address, and the second part of the operation is +//! enqueued to open_mem_align_ops, it will processed when processing next address. +//! +//! First, we verify if there are any "previous" open memory alignment operations +//! (`open_mem_align_ops`) that need to be processed before handling the current `mem_op`. If such +//! operations exist, they are processed first, and then the current `mem_op` is executed. +//! +//! At the end of Step 2, a final marker is used to ensure a forced flush of any remaining +//! `open_mem_align_ops`. This guarantees that all pending alignment operations are completed, +//! as the last operation in this step is always a `mem_op`. +//! +//! +//! ## Handling Large Gaps Between Steps +//! +//! One challenge in the design is addressing cases where the distance between steps becomes +//! more large than max range check MEMORY_MAX_DIFF (current 2^24). This solve this situation +//! the proxy add extra intermediate internal reads (internal because don't send to bus), each +//! increase step in MEMORY_MAX_DIFF to arrive to the final step. + +use std::{collections::VecDeque, sync::Arc}; + +use crate::{ + MemAlignInput, MemAlignResponse, MemAlignSM, MemHelpers, MemInput, MemModule, MemUnmapped, + MAX_MAIN_STEP, MAX_MEM_ADDR, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_STEP, MAX_MEM_STEP_OFFSET, + MEMORY_MAX_DIFF, MEM_ADDR_MASK, MEM_BYTES, MEM_BYTES_BITS, +}; +use log::info; + +use p3_field::PrimeField; +use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; +use zisk_core::ZiskRequiredMemory; + +#[cfg(feature = "debug_mem_proxy_engine")] +const DEBUG_ADDR: u32 = 0x90000008; + +macro_rules! debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_proxy_engine")] + { + info!(concat!("MemProxy: ",$prefix), $($arg)*); + } + }; +} + +struct MemModuleData { + pub name: String, + pub inputs: Vec, + pub flush_input_size: usize, +} + +#[derive(Debug)] +pub struct AddressRegion { + from_addr: u32, + to_addr: u32, + module_id: u8, +} +pub struct MemProxyEngine { + modules: Vec>>, + modules_data: Vec, + open_mem_align_ops: VecDeque, + addr_map: Vec, + addr_map_fetched: bool, + current_module_id: usize, + current_module: String, + module_end_addr: u32, + mem_align_sm: Arc>, + next_open_addr: u32, + next_open_step: u64, + last_addr: u32, + last_step: u64, + intermediate_cases: u32, + intermediate_steps: u32, +} + +const NO_OPEN_ADDR: u32 = 0xFFFF_FFFF; +const NO_OPEN_STEP: u64 = 0xFFFF_FFFF_FFFF_FFFF; + +impl MemProxyEngine { + pub fn new(mem_align_sm: Arc>) -> Self { + Self { + modules: Vec::new(), + modules_data: Vec::new(), + current_module_id: 0, + current_module: String::new(), + module_end_addr: 0, + open_mem_align_ops: VecDeque::new(), + addr_map: Vec::new(), + addr_map_fetched: false, + mem_align_sm, + next_open_addr: NO_OPEN_ADDR, + next_open_step: NO_OPEN_STEP, + last_addr: 0xFFFF_FFFF, + last_step: 0, + intermediate_cases: 0, + intermediate_steps: 0, + } + } + + pub fn add_module(&mut self, name: &str, module: Arc>) { + if self.modules.is_empty() { + self.current_module = String::from(name); + } + let module_id = self.modules.len() as u8; + self.modules.push(module.clone()); + + let ranges = module.get_addr_ranges(); + let flush_input_size = module.get_flush_input_size(); + + for range in ranges.iter() { + debug_info!("adding range 0x{:X} 0x{:X} to {}", range.0, range.1, name); + self.insert_address_range(range.0, range.1, module_id); + } + self.modules_data.push(MemModuleData { + name: String::from(name), + inputs: Vec::new(), + flush_input_size: if flush_input_size == 0 { + 0xFFFF_FFFF_FFFF_FFFF + } else { + flush_input_size as usize + }, + }); + } + /* insert in sort way the address map and verify that */ + fn insert_address_range(&mut self, from_addr: u32, to_addr: u32, module_id: u8) { + let region = AddressRegion { from_addr, to_addr, module_id }; + if let Some(index) = self.addr_map.iter().position(|x| x.from_addr >= from_addr) { + self.addr_map.insert(index, region); + } else { + self.addr_map.push(region); + } + } + + pub fn prove( + &mut self, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + self.init_prove(); + + // Sort the aligned memory accesses + // original vector is sorted by step, sort_by_key is stable, no reordering of elements with + // the same key. + + timer_start_debug!(MEM_SORT); + mem_operations.sort_by_key(|mem| (mem.get_address() & 0xFFFF_FFF8)); + timer_stop_and_log_debug!(MEM_SORT); + + // Add a final mark mem_op to force flush of open_mem_align_ops, because always the + // last operation is mem_op. + + self.push_end_of_memory_mark(mem_operations); + + let mut index = 0; + let count = mem_operations.len(); + while index < count { + if let ZiskRequiredMemory::Basic { + step, + value, + address, + is_write, + width, + step_offset, + } = mem_operations[index] + { + let extend_values = if !Self::is_aligned(address, width) { + debug_assert!(index + 1 < count, "expected one element more extended !!"); + if let ZiskRequiredMemory::Extended { address: _, values } = + mem_operations[index + 1] + { + index += 1; + values + } else { + panic!("MemProxy::prove() unexpected Basic variant"); + } + } else { + [0, 0] + }; + index += 1; + if !self.prove_one( + address, + MemHelpers::main_step_to_address_step(step, step_offset), + value, + is_write, + width, + extend_values, + ) { + break; + } + } else { + panic!("MemProxy::prove() unexpected Extended variant"); + } + } + self.finish_prove(); + Ok(()) + } + + fn prove_one( + &mut self, + addr: u32, + mem_step: u64, + value: u64, + is_write: bool, + width: u8, + extend_values: [u64; 2], + ) -> bool { + let is_aligned: bool = Self::is_aligned(addr, width); + let aligned_mem_addr = Self::to_aligned_addr(addr); + + // Check if there are open mem align operations to be processed in this moment, + // with address (or step) less than the aligned of current + // mem_op. + self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step); + + // check if we are at end of loop + if self.check_if_end_of_memory_mark(addr, mem_step) { + return false; + } + + // all open mem align operations are processed, check if new mem operation is + // aligned + if !is_aligned { + // In this point found non-aligned memory access, phase-0 + let mem_align_input = MemAlignInput { + addr, + value, + width, + mem_values: extend_values, + is_write, + step: mem_step, + }; + let mem_align_response = self.mem_align_sm.get_mem_op(&mem_align_input, 0); + + #[cfg(feature = "debug_mem_proxy_engine")] + Self::debug_mem_align_api(&mem_align_input, &mem_align_response, 0); + + // if operation applies to two consecutive memory addresses, add the second + // part is enqueued to be processed in future when + // processing next address on phase-1 + self.push_mem_align_response_ops( + aligned_mem_addr, + extend_values[0], + &mem_align_input, + &mem_align_response, + ); + if mem_align_response.more_addr { + self.open_mem_align_ops.push_back(mem_align_input); + self.update_next_open_mem_align(); + } + } else { + self.push_aligned_op(is_write, addr, value, mem_step); + } + true + } + + fn update_next_open_mem_align(&mut self) { + if self.open_mem_align_ops.is_empty() { + self.next_open_addr = NO_OPEN_ADDR; + self.next_open_step = NO_OPEN_STEP; + } else if self.open_mem_align_ops.len() == 1 { + let mem_align_input = self.open_mem_align_ops.front().unwrap(); + self.next_open_addr = Self::next_aligned_addr(mem_align_input.addr); + self.next_open_step = mem_align_input.step; + } + } + + fn process_all_previous_open_mem_align_ops(&mut self, mem_addr: u32, mem_step: u64) { + // Two possible situations to process open mem align operations: + // + // 1) the address of open operation is less than the aligned address. + // 2) the address of open operation is equal to the aligned address, but the step of the + // open operation is less than the step of the current operation. + + while let Some(open_op) = self.get_next_open_mem_align_input(mem_addr, mem_step) { + // call to mem_align to get information of the aligned memory access needed + // to prove the unaligned open operation. + let mem_align_resp = self.mem_align_sm.get_mem_op(&open_op, 1); + + #[cfg(feature = "debug_mem_proxy_engine")] + Self::debug_mem_align_api(&open_op, &mem_align_resp, 1); + + // push the aligned memory operations for current address (read or read+write) and + // update last_address and last_value. + self.push_mem_align_response_ops( + Self::next_aligned_addr(open_op.addr), + open_op.mem_values[1], + &open_op, + &mem_align_resp, + ); + } + } + + pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_BY_MAIN_STEP * step + 2 * step_offset as u64 + } + + #[inline(always)] + fn is_aligned(address: u32, width: u8) -> bool { + ((address & 0x07) == 0) && (width == 8) + } + + fn push_aligned_op(&mut self, is_write: bool, addr: u32, value: u64, step: u64) { + self.update_mem_module(addr); + let w_addr = Self::to_aligned_word_addr(addr); + + // check if step difference is too large + if self.last_addr == w_addr && (step - self.last_step) > MEMORY_MAX_DIFF { + self.push_intermediate_internal_reads(w_addr, value, self.last_step, step); + } + + self.last_step = step; + self.last_addr = w_addr; + + let mem_op = MemInput { step, is_write, is_internal: false, addr: w_addr, value }; + debug_info!( + "route ==> {}[{:X}] {} {} #{}", + self.current_module, + mem_op.addr << MEM_BYTES_BITS, + if is_write { "W" } else { "R" }, + value, + step, + ); + self.internal_push_mem_op(mem_op); + } + + fn push_intermediate_internal_reads( + &mut self, + addr: u32, + value: u64, + last_step: u64, + final_step: u64, + ) { + let mut step = last_step; + self.intermediate_cases += 1; + while (final_step - step) > MEMORY_MAX_DIFF { + self.intermediate_steps += 1; + step += MEMORY_MAX_DIFF; + let mem_op = MemInput { step, is_write: false, is_internal: true, addr, value }; + self.internal_push_mem_op(mem_op); + } + } + + fn internal_push_mem_op(&mut self, mem_op: MemInput) { + self.modules_data[self.current_module_id].inputs.push(mem_op); + self.check_flush_inputs(); + } + // method to add aligned read operation + #[inline(always)] + fn push_aligned_read(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(false, addr, value, step); + } + // method to add aligned write operation + #[inline(always)] + fn push_aligned_write(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(true, addr, value, step); + } + /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible + /// situations: + /// 1) read, only on single mem_op is pushed + /// 2) read+write, two mem_op are pushed, one read and one write. + /// + /// This process is used for each aligned memory address, means that the "second part" of non + /// aligned memory operation is processed on addr + MEM_BYTES. + fn push_mem_align_response_ops( + &mut self, + mem_addr: u32, + mem_value: u64, + mem_align_input: &MemAlignInput, + mem_align_resp: &MemAlignResponse, + ) { + self.push_aligned_read(mem_addr, mem_value, mem_align_resp.step); + if mem_align_input.is_write { + self.push_aligned_write( + mem_addr, + mem_align_resp.value.unwrap(), + mem_align_resp.step + 1, + ); + } + } + fn set_active_region(&mut self, region_id: usize) { + self.current_module_id = self.addr_map[region_id].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.addr_map[region_id].to_addr; + } + fn update_mem_module_id(&mut self, addr: u32) { + debug_info!("search module for address 0x{:X}", addr); + if let Some(index) = + self.addr_map.iter().position(|x| x.from_addr <= addr && x.to_addr >= addr) + { + self.set_active_region(index); + } else { + panic!("out-of-memory 0x{:X}", addr); + } + } + fn update_mem_module(&mut self, addr: u32) { + // check if need to reevaluate the module id + if addr > self.module_end_addr { + self.update_mem_module_id(addr); + } + } + fn check_flush_inputs(&mut self) { + // check if need to flush the inputs of the module + let mid = self.current_module_id; + let inputs = self.modules_data[mid].inputs.len(); + if inputs >= self.modules_data[mid].flush_input_size { + // TODO: optimize passing ownership of inputs to module, and creating a new input + // object + debug_info!("flush {} inputs => {}", inputs, self.current_module); + self.modules[mid].send_inputs(&self.modules_data[mid].inputs); + self.modules_data[mid].inputs.clear(); + } + } + + fn get_next_open_mem_align_input(&mut self, addr: u32, step: u64) -> Option { + if self.next_open_addr < addr || (self.next_open_addr == addr && self.next_open_step < step) + { + let open_op = self.open_mem_align_ops.pop_front().unwrap(); + self.update_next_open_mem_align(); + Some(open_op) + } else { + None + } + } + // method to process open mem align operations, second part of non aligned memory operations + // applies to two consecutive memory addresses. + + fn push_end_of_memory_mark(&mut self, mem_operations: &mut Vec) { + mem_operations.push(ZiskRequiredMemory::Basic { + step: MAX_MAIN_STEP, + step_offset: MAX_MEM_STEP_OFFSET as u8, + is_write: false, + address: MAX_MEM_ADDR as u32, + width: MEM_BYTES as u8, + value: 0, + }); + mem_operations + .push(ZiskRequiredMemory::Extended { address: MAX_MEM_ADDR as u32, values: [0, 0] }); + } + + /// Check if the address is the "special" address inserted at the end of the memory operations + #[inline(always)] + fn check_if_end_of_memory_mark(&self, addr: u32, mem_step: u64) -> bool { + if addr == MAX_MEM_ADDR as u32 && mem_step == MAX_MEM_STEP { + debug_assert!( + self.open_mem_align_ops.is_empty(), + "open_mem_align_ops not empty, has {} elements", + self.open_mem_align_ops.len() + ); + true + } else { + false + } + } + /// Encapsulates all tasks to be performed at the beginning of the witness computation (stage + /// 1). + /// + /// This method fetches the address map and sets the initial values to prepare for the + /// computation. + fn init_prove(&mut self) { + if !self.addr_map_fetched { + self.fetch_address_map(); + } + self.current_module_id = self.addr_map[0].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.addr_map[0].to_addr; + } + /// Encapsulates all tasks to be performed at the end of the witness computation (stage 1). + /// + /// This method flushes all module inputs to ensure they are finalized and ready for further + /// processing. + fn finish_prove(&self) { + for (module_id, module) in self.modules.iter().enumerate() { + debug_info!( + "{}: flush all({}) inputs", + self.modules_data[module_id].name, + self.modules_data[module_id].inputs.len() + ); + module.send_inputs(&self.modules_data[module_id].inputs); + } + info!( + "MemProxy: ··· Intermediate reads [cases:{} steps:{}]", + self.intermediate_cases, self.intermediate_steps + ); + } + /// Fetches the address map, defining and calculating all necessary structures to manage the + /// memory map. + /// + /// For undefined regions (such as memory between defined regions, or memory at the beginning or + /// end of the memory map), this method assigns an unmapped module. If any access occurs + /// within these unmapped memory regions, the method will trigger a panic. + /// + /// The unmapped module ensures that every address has an associated module to handle memory + /// access, providing a safety mechanism to prevent undefined behavior. + fn fetch_address_map(&mut self) { + let unmapped_regions: Vec<(u32, u32)> = self.get_unmapped_regions(); + if !unmapped_regions.is_empty() { + self.define_unmapped_module(&unmapped_regions); + } + self.addr_map_fetched = true; + } + + /// Get list of regions (from_addr, to_addr) that are not defined in the memory map + fn get_unmapped_regions(&self) -> Vec<(u32, u32)> { + let mut next_addr = 0; + let mut unmapped_regions: Vec<(u32, u32)> = Vec::new(); + for addr_region in self.addr_map.iter() { + if next_addr < addr_region.from_addr { + unmapped_regions.push((next_addr, addr_region.from_addr - 1)); + } + next_addr = addr_region.to_addr + 1; + } + unmapped_regions + } + + /// Define an unmapped module with all unmapped regions. + fn define_unmapped_module(&mut self, unmapped_regions: &[(u32, u32)]) { + let mut unmapped_module = MemUnmapped::::new(); + for unmapped_region in unmapped_regions.iter() { + unmapped_module.add_range(unmapped_region.0, unmapped_region.1); + } + self.add_module("unmapped", Arc::new(unmapped_module)); + } + + /// Calculate aligned address from regular address (aligned or not) + #[inline(always)] + fn to_aligned_addr(addr: u32) -> u32 { + addr & MEM_ADDR_MASK + } + + /// Calculate the next aligned address from regular address (aligned or not) + #[inline(always)] + fn next_aligned_addr(addr: u32) -> u32 { + (addr & MEM_ADDR_MASK) + MEM_BYTES + } + + /// Calculate the word address where word is MEM_BYTES + #[inline(always)] + fn to_aligned_word_addr(addr: u32) -> u32 { + addr >> MEM_BYTES_BITS + } + + #[cfg(feature = "debug_mem_proxy_engine")] + fn debug_mem_align_api( + mem_align_input: &MemAlignInput, + mem_align_response: &MemAlignResponse, + phase: u8, + ) { + if mem_align_input.addr >= DEBUG_ADDR - 8 && mem_align_input.addr <= DEBUG_ADDR + 8 { + debug_info!( + "mem_align_input_{:X}: phase:{} {:?}", + mem_align_input.addr, + phase, + mem_align_input + ); + debug_info!( + "mem_align_response_{:X}: phase:{} {:?}", + mem_align_input.addr, + phase, + mem_align_response + ); + } + } +} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs new file mode 100644 index 00000000..051277e6 --- /dev/null +++ b/state-machines/mem/src/mem_sm.rs @@ -0,0 +1,383 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{MemInput, MemModule, MEMORY_MAX_DIFF, MEM_BYTES_BITS}; +use num_bigint::BigInt; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use zisk_core::{RAM_ADDR, RAM_SIZE}; +use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; + +const RAM_W_ADDR_INIT: u32 = RAM_ADDR as u32 >> MEM_BYTES_BITS; +const RAM_W_ADDR_END: u32 = (RAM_ADDR + RAM_SIZE - 1) as u32 >> MEM_BYTES_BITS; + +const _: () = { + assert!((RAM_SIZE - 1) >> MEM_BYTES_BITS <= MEMORY_MAX_DIFF, "RAM is too large"); + assert!( + (RAM_ADDR + RAM_SIZE - 1) <= 0xFFFF_FFFF, + "RAM memory exceeds the 32-bit addressable range" + ); +}; + +pub struct MemSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, +} + +#[derive(Default)] +pub struct MemAirValues { + pub segment_id: u32, + pub is_first_segment: bool, + pub is_last_segment: bool, + pub previous_segment_addr: u32, + pub previous_segment_step: u64, + pub previous_segment_value: [u32; 2], + pub segment_last_addr: u32, + pub segment_last_step: u64, + pub segment_last_value: [u32; 2], +} +#[derive(Debug)] +pub struct MemPreviousSegment { + pub addr: u32, + pub step: u64, + pub value: u64, +} + +#[allow(unused, unused_variables)] +impl MemSM { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let mem_sm = + Self { wcm: wcm.clone(), std: std.clone(), registered_predecessors: AtomicU32::new(0) }; + let mem_sm = Arc::new(mem_sm); + + wcm.register_component(mem_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MEM_AIR_IDS)); + std.register_predecessor(); + + mem_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } + } + + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs + // memory only need to process these special inputs, but inputs no change. At end of + // inputs proxy add an extra internal input to jump to last address + + let air_id = MEM_AIR_IDS[0]; + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, air_id); + let air_rows = air.num_rows(); + + // at least one row to go + let count = inputs.len(); + let count_rem = count % air_rows; + let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; + + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_segments]); + let mut global_idxs = vec![0; num_segments]; + + #[allow(clippy::needless_range_loop)] + for i in 0..num_segments { + // TODO: Review + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1) + { + let trace: MemTrace<'_, _> = MemTrace::new(air_rows); + let mut buffer = trace.buffer.unwrap(); + + prover_buffers.lock().unwrap()[i] = buffer; + global_idxs[i] = global_idx; + } + } + + #[allow(clippy::needless_range_loop)] + for segment_id in 0..num_segments { + let is_last_segment = segment_id == num_segments - 1; + let input_offset = segment_id * air_rows; + let previous_segment = if (segment_id == 0) { + MemPreviousSegment { addr: RAM_W_ADDR_INIT, step: 0, value: 0 } + } else { + MemPreviousSegment { + addr: inputs[input_offset - 1].addr, + step: inputs[input_offset - 1].step, + value: inputs[input_offset - 1].value, + } + }; + let input_end = + if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; + let mem_ops = &inputs[input_offset..input_end]; + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + &previous_segment, + prover_buffer, + air_rows, + global_idxs[segment_id], + ); + } + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. + /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `MemoryInput` inputs + #[allow(clippy::too_many_arguments)] + pub fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + previous_segment: &MemPreviousSegment, + mut prover_buffer: Vec, + air_mem_rows: usize, + global_idx: usize, + ) -> Result<(), Box> { + assert!( + !mem_ops.is_empty() && mem_ops.len() <= air_mem_rows, + "MemSM: mem_ops.len()={} out of range {}", + mem_ops.len(), + air_mem_rows + ); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + let mut trace = MemTrace::::map_buffer(&mut prover_buffer, air_mem_rows, 0).unwrap(); + + let mut range_check_data: Vec = vec![0; MEMORY_MAX_DIFF as usize]; + + let mut air_values = MemAirValues { + segment_id: segment_id as u32, + is_first_segment: segment_id == 0, + is_last_segment, + previous_segment_addr: previous_segment.addr, + previous_segment_step: previous_segment.step, + previous_segment_value: [ + previous_segment.value as u32, + (previous_segment.value >> 32) as u32, + ], + ..MemAirValues::default() + }; + + // index it's value - 1, for this reason no add +1 + range_check_data[(previous_segment.addr - RAM_W_ADDR_INIT) as usize] += 1; // TODO + + // Fill the remaining rows + let mut last_addr: u32 = previous_segment.addr; + let mut last_step: u64 = previous_segment.step; + let mut last_value: u64 = previous_segment.value; + + for (i, mem_op) in mem_ops.iter().enumerate() { + trace[i].addr = F::from_canonical_u32(mem_op.addr); + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = F::from_bool(!mem_op.is_internal); + trace[i].wr = F::from_bool(mem_op.is_write); + + let (low_val, high_val) = self.get_u32_values(mem_op.value); + trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + let addr_changes = last_addr != mem_op.addr; + trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; + + let increment = if addr_changes { + // (mem_op.addr - last_addr + if i == 0 && segment_id == 0 { 1 } else { 0 }) as u64 + (mem_op.addr - last_addr) as u64 + } else { + mem_op.step - last_step + }; + trace[i].increment = F::from_canonical_u64(increment); + + // Store the value of incremenet so it can be range checked + if increment <= MEMORY_MAX_DIFF || increment == 0 { + range_check_data[(increment - 1) as usize] += 1; + } else { + panic!("MemSM: increment's out of range: {} i:{} addr_changes:{} mem_op.addr:0x{:X} last_addr:0x{:X} mem_op.step:{} last_step:{}", + increment, i, addr_changes as u8, mem_op.addr, last_addr, mem_op.step, last_step); + } + + last_addr = mem_op.addr; + last_step = mem_op.step; + last_value = mem_op.value; + } + + // STEP3. Add dummy rows to the output vector to fill the remaining rows + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd + // = 1, wr = 0 + let last_row_idx = mem_ops.len() - 1; + let addr = trace[last_row_idx].addr; + let value = trace[last_row_idx].value; + + let padding_size = air_mem_rows - mem_ops.len(); + for i in mem_ops.len()..air_mem_rows { + last_step += 1; + trace[i].addr = addr; + trace[i].step = F::from_canonical_u64(last_step); + trace[i].sel = F::zero(); + trace[i].wr = F::zero(); + + trace[i].value = value; + + trace[i].addr_changes = F::zero(); + trace[i].increment = F::one(); + } + + air_values.segment_last_addr = last_addr; + air_values.segment_last_step = last_step; + air_values.segment_last_value[0] = last_value as u32; + air_values.segment_last_value[1] = (last_value >> 32) as u32; + + // Store the value of trivial increment so that they can be range checked + // value = 1 => index = 0 + range_check_data[0] += padding_size as u64; + + // no add extra +1 because index = value - 1 + // RAM_W_ADDR_END - last_addr + 1 - 1 = RAM_W_ADDR_END - last_addr + range_check_data[(RAM_W_ADDR_END - last_addr) as usize] += 1; // TODO + + // TODO: Perform the range checks + let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + for (value, &multiplicity) in range_check_data.iter().enumerate() { + if (multiplicity == 0) { + continue; + } + self.std.range_check( + F::from_canonical_usize(value + 1), + F::from_canonical_u64(multiplicity), + range_id, + ); + } + + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + let mut air_instance = AirInstance::new( + sctx.clone(), + ZISK_AIRGROUP_ID, + MEM_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + self.set_airvalues("Mem", &mut air_instance, &air_values); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + + Ok(()) + } + + fn get_u32_values(&self, value: u64) -> (u32, u32) { + (value as u32, (value >> 32) as u32) + } + fn set_airvalues( + &self, + prefix: &str, + air_instance: &mut AirInstance, + air_values: &MemAirValues, + ) { + air_instance.set_airvalue( + format!("{}.segment_id", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_id), + ); + air_instance.set_airvalue( + format!("{}.is_first_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_first_segment), + ); + air_instance.set_airvalue( + format!("{}.is_last_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_last_segment), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.previous_segment_addr), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.previous_segment_step), + ); + air_instance.set_airvalue( + format!("{}.segment_last_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_last_addr), + ); + air_instance.set_airvalue( + format!("{}.segment_last_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.segment_last_step), + ); + let count = air_values.previous_segment_value.len(); + for i in 0..count { + air_instance.set_airvalue( + format!("{}.previous_segment_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.previous_segment_value[i]), + ); + air_instance.set_airvalue( + format!("{}.segment_last_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.segment_last_value[i]), + ); + } + } +} + +impl MemModule for MemSM { + fn send_inputs(&self, mem_op: &[MemInput]) { + self.prove(mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(RAM_ADDR as u32, (RAM_ADDR + RAM_SIZE - 1) as u32)] + } + fn get_flush_input_size(&self) -> u32 { + 0 + } +} + +impl WitnessComponent for MemSM {} diff --git a/state-machines/mem/src/mem_traces.rs b/state-machines/mem/src/mem_traces.rs deleted file mode 100644 index c80a8c74..00000000 --- a/state-machines/mem/src/mem_traces.rs +++ /dev/null @@ -1,5 +0,0 @@ -use proofman_common as common; -pub use proofman_macros::trace; - -trace!(MemALignedRow, MemALignedTrace { fake: F }); -trace!(MemUnaLignedRow, MemUnaLignedTrace { fake: F}); diff --git a/state-machines/mem/src/mem_unaligned.rs b/state-machines/mem/src/mem_unaligned.rs deleted file mode 100644 index fde238e3..00000000 --- a/state-machines/mem/src/mem_unaligned.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemUnalignedOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct MemUnalignedSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -#[allow(unused, unused_variables)] -impl MemUnalignedSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); - - wcm.register_component( - mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_UNALIGNED_AIR_IDS), - ); - - mem_aligned_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - fn read( - &self, - _addr: u64, - _width: usize, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } - - fn write( - &self, - _addr: u64, - _width: usize, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } -} - -impl WitnessComponent for MemUnalignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemUnalignedSM { - fn calculate(&self, operation: MemUnalignedOp) -> Result> { - match operation { - MemUnalignedOp::Read(addr, width) => self.read(addr, width), - MemUnalignedOp::Write(addr, width, val) => self.write(addr, width, val), - } - } - - fn prove(&self, operations: &[MemUnalignedOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: MemUnalignedOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/mem/src/mem_unmapped.rs b/state-machines/mem/src/mem_unmapped.rs new file mode 100644 index 00000000..2ec61685 --- /dev/null +++ b/state-machines/mem/src/mem_unmapped.rs @@ -0,0 +1,35 @@ +use std::marker::PhantomData; + +use crate::{MemInput, MemModule}; +use p3_field::PrimeField; + +pub struct MemUnmapped { + ranges: Vec<(u32, u32)>, + __data: PhantomData, +} + +impl Default for MemUnmapped { + fn default() -> Self { + Self::new() + } +} + +impl MemUnmapped { + pub fn new() -> Self { + Self { ranges: Vec::new(), __data: PhantomData } + } + pub fn add_range(&mut self, _start: u32, _end: u32) { + self.ranges.push((_start, _end)); + } +} +impl MemModule for MemUnmapped { + fn send_inputs(&self, _mem_op: &[MemInput]) { + // panic!("[MemUnmapped] invalid access to addr {:x}", _mem_op[0].addr); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + self.ranges.to_vec() + } + fn get_flush_input_size(&self) -> u32 { + 1 + } +} diff --git a/state-machines/mem/src/rom_data.rs b/state-machines/mem/src/rom_data.rs new file mode 100644 index 00000000..57243430 --- /dev/null +++ b/state-machines/mem/src/rom_data.rs @@ -0,0 +1,339 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{ + MemAirValues, MemInput, MemModule, MemPreviousSegment, MEMORY_MAX_DIFF, MEM_BYTES_BITS, +}; +use num_bigint::BigInt; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; +use zisk_core::{ROM_ADDR, ROM_ADDR_MAX}; +use zisk_pil::{RomDataTrace, ROM_DATA_AIR_IDS, ZISK_AIRGROUP_ID}; + +const ROM_W_ADDR: u32 = ROM_ADDR as u32 >> MEM_BYTES_BITS; +const ROM_W_ADDR_END: u32 = ROM_ADDR_MAX as u32 >> MEM_BYTES_BITS; + +const _: () = { + assert!( + (ROM_ADDR_MAX - ROM_ADDR) >> MEM_BYTES_BITS as u64 <= MEMORY_MAX_DIFF, + "ROM_DATA is too large" + ); + assert!(ROM_ADDR_MAX <= 0xFFFF_FFFF, "ROM_DATA memory exceeds the 32-bit addressable range"); +}; + +pub struct RomDataSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + num_rows: usize, + // Count of registered predecessors + registered_predecessors: AtomicU32, +} + +#[allow(unused, unused_variables)] +impl RomDataSM { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, ROM_DATA_AIR_IDS[0]); + let rom_data_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + num_rows: air.num_rows(), + registered_predecessors: AtomicU32::new(0), + }; + let rom_data_sm = Arc::new(rom_data_sm); + + wcm.register_component(rom_data_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(ROM_DATA_AIR_IDS)); + std.register_predecessor(); + + rom_data_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } + } + + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs + // memory only need to process these special inputs, but inputs no change. At end of + // inputs proxy add an extra internal input to jump to last address + + let air_id = ROM_DATA_AIR_IDS[0]; + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, air_id); + let air_rows = air.num_rows(); + + // at least one row to go + let count = inputs.len(); + let count_rem = count % air_rows; + let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; + + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_segments]); + let mut global_idxs = vec![0; num_segments]; + + #[allow(clippy::needless_range_loop)] + for i in 0..num_segments { + // TODO: Review + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1) + { + let trace: RomDataTrace<'_, _> = RomDataTrace::new(air_rows); + let mut buffer = trace.buffer.unwrap(); + prover_buffers.lock().unwrap()[i] = buffer; + global_idxs[i] = global_idx; + } + } + + #[allow(clippy::needless_range_loop)] + for segment_id in 0..num_segments { + let is_last_segment = segment_id == num_segments - 1; + let input_offset = segment_id * air_rows; + let previous_segment = if (segment_id == 0) { + MemPreviousSegment { addr: ROM_W_ADDR, step: 0, value: 0 } + } else { + MemPreviousSegment { + addr: inputs[input_offset - 1].addr, + step: inputs[input_offset - 1].step, + value: inputs[input_offset - 1].value, + } + }; + let input_end = + if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; + let mem_ops = &inputs[input_offset..input_end]; + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + &previous_segment, + prover_buffer, + air_rows, + global_idxs[segment_id], + ); + } + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. + /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `MemoryInput` inputs + #[allow(clippy::too_many_arguments)] + pub fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + previous_segment: &MemPreviousSegment, + mut prover_buffer: Vec, + air_mem_rows: usize, + global_idx: usize, + ) -> Result<(), Box> { + assert!( + !mem_ops.is_empty() && mem_ops.len() <= air_mem_rows, + "RomDataSM: mem_ops.len()={} out of range {}", + mem_ops.len(), + air_mem_rows + ); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + let mut trace = RomDataTrace::::map_buffer(&mut prover_buffer, air_mem_rows, 0).unwrap(); + + let mut air_values = MemAirValues { + segment_id: segment_id as u32, + is_first_segment: segment_id == 0, + is_last_segment, + previous_segment_addr: previous_segment.addr, + previous_segment_step: previous_segment.step, + previous_segment_value: [ + previous_segment.value as u32, + (previous_segment.value >> 32) as u32, + ], + ..MemAirValues::default() + }; + + // range of instance + let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + self.std.range_check( + F::from_canonical_u32(previous_segment.addr - ROM_W_ADDR + 1), + F::one(), + range_id, + ); + + // Fill the remaining rows + let mut last_addr: u32 = previous_segment.addr; + let mut last_step: u64 = previous_segment.step; + let mut last_value: u64 = previous_segment.value; + + for (i, mem_op) in mem_ops.iter().enumerate() { + trace[i].addr = F::from_canonical_u32(mem_op.addr); + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = F::from_bool(!mem_op.is_internal); + + let (low_val, high_val) = self.get_u32_values(mem_op.value); + trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + let addr_changes = last_addr != mem_op.addr; + trace[i].addr_changes = + if addr_changes || (i == 0 && segment_id == 0) { F::one() } else { F::zero() }; + + last_addr = mem_op.addr; + last_step = mem_op.step; + last_value = mem_op.value; + } + + // STEP3. Add dummy rows to the output vector to fill the remaining rows + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd + // = 1, wr = 0 + let last_row_idx = mem_ops.len() - 1; + let addr = trace[last_row_idx].addr; + let value = trace[last_row_idx].value; + + let padding_size = air_mem_rows - mem_ops.len(); + for i in mem_ops.len()..air_mem_rows { + last_step += 1; + trace[i].addr = addr; + trace[i].step = F::from_canonical_u64(last_step); + trace[i].sel = F::zero(); + + trace[i].value = value; + + trace[i].addr_changes = F::zero(); + } + + air_values.segment_last_addr = last_addr; + air_values.segment_last_step = last_step; + air_values.segment_last_value[0] = last_value as u32; + air_values.segment_last_value[1] = (last_value >> 32) as u32; + + self.std.range_check( + F::from_canonical_u32(ROM_W_ADDR_END - last_addr + 1), + F::one(), + range_id, + ); + + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + let mut air_instance = AirInstance::new( + sctx.clone(), + ZISK_AIRGROUP_ID, + ROM_DATA_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + self.set_airvalues("RomData", &mut air_instance, &air_values); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + + Ok(()) + } + + fn get_u32_values(&self, value: u64) -> (u32, u32) { + (value as u32, (value >> 32) as u32) + } + fn set_airvalues( + &self, + prefix: &str, + air_instance: &mut AirInstance, + air_values: &MemAirValues, + ) { + air_instance.set_airvalue( + format!("{}.segment_id", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_id), + ); + air_instance.set_airvalue( + format!("{}.is_first_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_first_segment), + ); + air_instance.set_airvalue( + format!("{}.is_last_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_last_segment), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.previous_segment_addr), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.previous_segment_step), + ); + air_instance.set_airvalue( + format!("{}.segment_last_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_last_addr), + ); + air_instance.set_airvalue( + format!("{}.segment_last_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.segment_last_step), + ); + let count = air_values.previous_segment_value.len(); + for i in 0..count { + air_instance.set_airvalue( + format!("{}.previous_segment_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.previous_segment_value[i]), + ); + air_instance.set_airvalue( + format!("{}.segment_last_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.segment_last_value[i]), + ); + } + } +} + +impl MemModule for RomDataSM { + fn send_inputs(&self, mem_op: &[MemInput]) { + self.prove(mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(ROM_ADDR as u32, ROM_ADDR_MAX as u32)] + } + fn get_flush_input_size(&self) -> u32 { + self.num_rows as u32 + } +} + +impl WitnessComponent for RomDataSM {} diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 8d89c2dc..52a2b591 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -9,16 +9,18 @@ use rayon::prelude::*; use sm_arith::ArithSM; use sm_binary::BinarySM; use sm_main::{InstanceExtensionCtx, MainSM}; -use sm_mem::MemSM; +use sm_mem::MemProxy; use sm_rom::RomSM; use std::{ fs, path::{Path, PathBuf}, sync::Arc, + thread, }; use zisk_core::{Riscv2zisk, ZiskOperationType, ZiskRom, ZISK_OPERATION_TYPE_VARIANTS}; use zisk_pil::{ - ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, + ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ROM_AIR_IDS, + ZISK_AIRGROUP_ID, }; use ziskemu::{EmuOptions, ZiskEmulator}; @@ -33,7 +35,7 @@ pub struct ZiskExecutor { pub rom_sm: Arc>, /// Memory State Machine - pub mem_sm: Arc, + pub mem_proxy_sm: Arc>, /// Binary State Machine pub binary_sm: Arc>, @@ -49,7 +51,7 @@ impl ZiskExecutor { let std = Std::new(wcm.clone()); let rom_sm = RomSM::new(wcm.clone()); - let mem_sm = MemSM::new(wcm.clone()); + let mem_proxy_sm = MemProxy::new(wcm.clone(), std.clone()); let binary_sm = BinarySM::new(wcm.clone(), std.clone()); let arith_sm = ArithSM::new(wcm.clone(), binary_sm.clone()); @@ -81,9 +83,10 @@ impl ZiskExecutor { // TODO - If there is more than one Main AIR available, the MAX_ACCUMULATED will be the one // with the highest num_rows. It has to be a power of 2. - let main_sm = MainSM::new(wcm.clone(), arith_sm.clone(), binary_sm.clone(), mem_sm.clone()); + let main_sm = + MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); - Self { zisk_rom, main_sm, rom_sm, mem_sm, binary_sm, arith_sm } + Self { zisk_rom, main_sm, rom_sm, mem_proxy_sm, binary_sm, arith_sm } } /// Executes the MainSM state machine and processes the inputs in batches when the maximum @@ -118,6 +121,7 @@ impl ZiskExecutor { let path = PathBuf::from(public_inputs_path.display().to_string()); fs::read(path).expect("Could not read inputs file") }; + let public_inputs = Arc::new(public_inputs); // During ROM processing, we gather execution data necessary for creating the AIR instances. // This data is collected by the emulator and includes the minimal execution trace, @@ -137,17 +141,36 @@ impl ZiskExecutor { op_sizes[ZiskOperationType::Binary as usize] = air_binary.num_rows() as u64; op_sizes[ZiskOperationType::BinaryE as usize] = air_binary_e.num_rows() as u64; + // STEP 1. Generate all inputs + // ============================================== + + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + move || { + ZiskEmulator::par_process_rom_memory::(&zisk_rom, &public_inputs) + .expect("Failed in ZiskEmulator::par_process_rom_memory") + } + }); + // ROM State Machine // ---------------------------------------------- // Run the ROM to compute the ROM witness - let rom_sm = self.rom_sm.clone(); - let zisk_rom = self.zisk_rom.clone(); - let pc_histogram = - ZiskEmulator::process_rom_pc_histogram(&self.zisk_rom, &public_inputs, &emu_options) - .expect( - "MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()", - ); - let handle_rom = std::thread::spawn(move || rom_sm.prove(&zisk_rom, pc_histogram)); + let rom_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + let emu_options_cloned = emu_options.clone(); + move || { + ZiskEmulator::process_rom_pc_histogram( + &zisk_rom, + &public_inputs, + &emu_options_cloned, + ) + .expect("MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()") + } + }); // Main, Binary and Arith State Machines // ---------------------------------------------- @@ -164,10 +187,43 @@ impl ZiskExecutor { .expect("Error during emulator execution"); timer_stop_and_log_debug!(PAR_PROCESS_ROM); - emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); + // STEP 2. Wait until all inputs are generated + // ============================================== + // Join all the threads to synchronize the execution + let mut mem_required = mem_thread.join().expect("Error during Memory witness computation"); + let rom_required = rom_thread.join().expect("Error during ROM witness computation"); + + // STEP 3. Generate AIRs and Prove + // ============================================== - // Join threads to synchronize the execution - handle_rom.join().unwrap().expect("Error during ROM witness computation"); + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let mem_proxy_sm = self.mem_proxy_sm.clone(); + move || { + mem_proxy_sm + .prove(&mut mem_required) + .expect("Error during Memory witness computation") + } + }); + + // ROM State Machine + // ---------------------------------------------- + let (rom_is_mine, _rom_instance_gid) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], 1); + + let rom_thread = if rom_is_mine { + let rom_sm = self.rom_sm.clone(); + let zisk_rom = self.zisk_rom.clone(); + + Some(thread::spawn(move || rom_sm.prove(&zisk_rom, rom_required))) + } else { + None + }; + + // Main, Binary and Arith State Machines + // ---------------------------------------------- + emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); // FIXME: Move InstanceExtensionCtx form main SM to another place let mut instances_extension_ctx: Vec> = @@ -232,7 +288,28 @@ impl ZiskExecutor { } timer_stop_and_log_debug!(ADD_INSTANCES_TO_THE_REPO); - // self.mem_sm.unregister_predecessor(scope); + mem_thread.join().expect("Error during Memory witness computation"); + + // match mem_thread.join() { + // Ok(_) => println!("El thread ha finalitzat correctament."), + // Err(e) => { + // println!("El thread ha fet panic!"); + // + // // Converteix l'error en una cadena llegible (opcional) + // if let Some(missatge) = e.downcast_ref::<&str>() { + // println!("Missatge d'error: {}", missatge); + // } else if let Some(missatge) = e.downcast_ref::() { + // println!("Missatge d'error: {}", missatge); + // } else { + // println!("No es pot determinar el tipus d'error."); + // } + // } + // } + if let Some(thread) = rom_thread { + let _ = thread.join().expect("Error during ROM witness computation"); + } + + self.mem_proxy_sm.unregister_predecessor(); self.binary_sm.unregister_predecessor(); self.arith_sm.unregister_predecessor(); } From 0c063e0d78b682281795a0924c2ea2aded29f37f Mon Sep 17 00:00:00 2001 From: zkronos73 <94566827+zkronos73@users.noreply.github.com> Date: Thu, 12 Dec 2024 23:48:25 +0100 Subject: [PATCH 5/6] add proofvalue to enable/disable input_data memory (#193) --- Cargo.lock | 28 ++++++++++++------------ Cargo.toml | 24 ++++++++++---------- emulator/src/emu.rs | 4 ++-- pil/zisk.pil | 7 ++++-- state-machines/arith/pil/arith_table.pil | 2 +- state-machines/mem/pil/mem.pil | 8 +++++-- state-machines/mem/src/input_data_sm.rs | 9 +++++++- 7 files changed, 48 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cd9341db..24c1ea80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1461,7 +1461,7 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "log", "num-bigint", @@ -1479,7 +1479,7 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "bytes", "log", @@ -1599,7 +1599,7 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "colored", "env_logger", @@ -1620,7 +1620,7 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "env_logger", "log", @@ -1639,7 +1639,7 @@ dependencies = [ [[package]] name = "proofman-hints" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "p3-field", "proofman-common", @@ -1649,7 +1649,7 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "proc-macro2", "quote", @@ -1659,7 +1659,7 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "log", ] @@ -1667,7 +1667,7 @@ dependencies = [ [[package]] name = "proofman-util" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "colored", "sysinfo 0.31.4", @@ -1847,9 +1847,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ "bitflags 2.6.0", ] @@ -2014,9 +2014,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.19" +version = "0.23.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" +checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" dependencies = [ "once_cell", "ring", @@ -2316,7 +2316,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "log", "p3-field", @@ -2665,7 +2665,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.16#5c47437feffccb16d95e120e8336ab8a168314e7" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=fix%2Fstarkpil-memcpy-airvalue-proofvalue#a6392a7f20056fa8bff95d60da3b7adabc7c5c74" dependencies = [ "proofman-starks-lib-c", ] diff --git a/Cargo.toml b/Cargo.toml index 9f456214..b10f41b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,19 +26,19 @@ opt-level = 3 opt-level = 3 [workspace.dependencies] -proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } -proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } -proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } -proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } -pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } -stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.16" } +proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } +proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } +proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } +proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } +pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } +stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } #Local development -# proofman-common = { path = "../pil2-proofman/common" } -# proofman-macros = { path = "../pil2-proofman/macros" } -# proofman-util = { path = "../pil2-proofman/util" } -# proofman = { path = "../pil2-proofman/proofman" } -# pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } -# stark = { path = "../pil2-proofman/provers/stark" } +#proofman-common = { path = "../pil2-proofman/common" } +#proofman-macros = { path = "../pil2-proofman/macros" } +#proofman-util = { path = "../pil2-proofman/util" } +#proofman = { path = "../pil2-proofman/proofman" } +#pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } +#stark = { path = "../pil2-proofman/provers/stark" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } log = "0.4" diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index 6ec7fa0f..93c6c1d5 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -415,7 +415,7 @@ impl<'a> Emu<'a> { } } - /// Set SP, if specified by the current instruction + // Set SP, if specified by the current instruction // #[cfg(feature = "sp")] // #[inline(always)] // pub fn set_sp(&mut self, instruction: &ZiskInst) { @@ -426,7 +426,7 @@ impl<'a> Emu<'a> { // } // } - /// Set PC, based on current PC, current flag and current instruction + // Set PC, based on current PC, current flag and current instruction #[inline(always)] pub fn set_pc(&mut self, instruction: &ZiskInst) { if instruction.set_pc { diff --git a/pil/zisk.pil b/pil/zisk.pil index bce36ffa..3a4d1111 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -11,16 +11,19 @@ require "arith/pil/arith.pil" const int OPERATION_BUS_ID = 5000; +proofval enable_input_data; +enable_input_data * (1 - enable_input_data); + airgroup Zisk { Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); Rom(N: 2**22); Mem(N: 2**21, RC: 2, base_address: 0xA000_0000); Mem(N: 2**21, RC: 2, base_address: 0x8000_0000, immutable: 1) alias RomData; - Mem(N: 2**21, RC: 2, base_address: 0x9000_0000, free_input_mem: 1) alias InputData; + Mem(N: 2**21, RC: 2, base_address: 0x9000_0000, free_input_mem: 1, enable_flag: enable_input_data) alias InputData; + MemAlign(N: 2**21); MemAlignRom(disable_fixed: 0); - // InputData(N: 2**21, RC: 2); Arith(N: 2**21, operation_bus_id: OPERATION_BUS_ID); ArithTable(); diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index e8bd35d7..53f8b238 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -2,7 +2,7 @@ require "std_lookup.pil" const int ARITH_TABLE_ID = 331; -airtemplate ArithTable(int N = 2**7, int generate_table = 1) { +airtemplate ArithTable(int N = 2**7, int generate_table = 0) { // div m32 sa sb primary secondary opcodes na nb np nr sext(c) // ----------------------------------------------------------------------------------- diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index f574f264..27052d91 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -50,7 +50,10 @@ const int MEMORY_MAX_DIFF = 2**24; const int MAX_MEM_STEP_OFFSET = 2; const int MAX_MEM_OPS_PER_MAIN_STEP = (MAX_MEM_STEP_OFFSET + 1) * 2; -airtemplate Mem(const int N = 2**21, const int id = MEMORY_ID, const int RC = 2, const int mem_bytes = 8, const int base_address = 0, const int mem_size = 0x800_0000, int immutable = 0, const int free_input_mem = 0) { +airtemplate Mem(const int N = 2**21, const int id = MEMORY_ID, const int RC = 2, const int mem_bytes = 8, + const int base_address = 0, const int mem_size = 0x800_0000, int immutable = 0, + const int free_input_mem = 0, const expr enable_flag = 1) { + col fixed SEGMENT_L1 = [1,0...]; const expr SEGMENT_LAST = SEGMENT_L1'; @@ -146,7 +149,8 @@ airtemplate Mem(const int N = 2**21, const int id = MEMORY_ID, const int RC = 2, for (int i = 0; i < length(zeros); ++i) { zeros[i] = 0; } - direct_global_update_proves(MEMORY_CONT_ID, [ base_address, 0, internal_base_address, 0, ...zeros]); + + direct_global_update_proves(MEMORY_CONT_ID, [ base_address, 0, internal_base_address, 0, ...zeros], sel: enable_flag); // for security check that first address has correct value, to avoid add huge quantity of instances to "overflow" prime field. range_check(colu: previous_segment_addr - internal_base_address + 1, min: 1, max: MEMORY_MAX_DIFF); diff --git a/state-machines/mem/src/input_data_sm.rs b/state-machines/mem/src/input_data_sm.rs index 220fc33f..2717e1e6 100644 --- a/state-machines/mem/src/input_data_sm.rs +++ b/state-machines/mem/src/input_data_sm.rs @@ -74,10 +74,17 @@ impl InputDataSM { self.std.unregister_predecessor(pctx, None); } } - pub fn prove(&self, inputs: &[MemInput]) { let wcm = self.wcm.clone(); let pctx = wcm.get_pctx(); + + if (inputs.is_empty()) { + pctx.set_proof_value("enable_input_data", F::zero()); + return; + } + + pctx.set_proof_value("enable_input_data", F::one()); + let ectx = wcm.get_ectx(); let sctx = wcm.get_sctx(); From 765bf0b00eab9dd2ffb1363b9e49432e0aac2ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Mon, 16 Dec 2024 07:37:44 +0000 Subject: [PATCH 6/6] updating cargo toml --- Cargo.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b10f41b6..524ff00d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,12 +26,12 @@ opt-level = 3 opt-level = 3 [workspace.dependencies] -proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } -proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } -proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } -proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } -pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } -stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch = "fix/starkpil-memcpy-airvalue-proofvalue" } +proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.17-pre2" } +proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.17-pre2" } +proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.17-pre2" } +proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.17-pre2" } +pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.17-pre2" } +stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.17-pre2" } #Local development #proofman-common = { path = "../pil2-proofman/common" } #proofman-macros = { path = "../pil2-proofman/macros" }