From 4ce782d62ab36b778d3988c69a359f9aa9824cdf Mon Sep 17 00:00:00 2001 From: AlexandreBelling Date: Thu, 24 Oct 2024 18:42:06 +0200 Subject: [PATCH 1/7] feat(mimc): adds a State and SetState functionality --- std/hash/mimc/mimc.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index 1d6fa4d35c..f0887c1756 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -61,6 +61,29 @@ func (h *MiMC) Reset() { h.h = 0 } +// SetState manually sets the state of the hasher to the provided value. In the +// case of MiMC only a single frontend variable is expected to represent the +// state. +func (h *MiMC) SetState(newState []frontend.Variable) error { + + if len(h.data) > 0 { + return errors.New("the hasher is not in an initial state") + } + + if len(newState) != 1 { + return errors.New("the MiMC hasher expects a single field element to represent the state") + } + + h.h = newState[0] + return nil +} + +// State returns the inner-state of the hasher. In the context of MiMC only a +// single field element is returned. +func (h *MiMC) State() []frontend.Variable { + return []frontend.Variable{h.h} +} + // Sum hash using [Miyaguchi–Preneel] where the XOR operation is replaced by // field addition. // From 88f16ef6750df0e58be287ac87796862b46b5f0c Mon Sep 17 00:00:00 2001 From: AlexandreBelling Date: Thu, 24 Oct 2024 18:43:29 +0200 Subject: [PATCH 2/7] fix: ensures the state is flushed --- std/hash/mimc/mimc.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index f0887c1756..1d8013ab13 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -75,12 +75,14 @@ func (h *MiMC) SetState(newState []frontend.Variable) error { } h.h = newState[0] + h.data = nil return nil } // State returns the inner-state of the hasher. In the context of MiMC only a // single field element is returned. func (h *MiMC) State() []frontend.Variable { + h.Sum() // this flushes the unsummed data return []frontend.Variable{h.h} } From 33fb4755517dee9a44e3439e92e6737c08aab7b1 Mon Sep 17 00:00:00 2001 From: AlexandreBelling Date: Thu, 12 Dec 2024 00:16:57 +0100 Subject: [PATCH 3/7] change the implementation to match the one of gnark crypto and adds a test --- std/hash/hash.go | 15 ++++++++ std/hash/mimc/mimc_test.go | 72 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/std/hash/hash.go b/std/hash/hash.go index 24854c577f..6fc8752aa7 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -40,6 +40,21 @@ type FieldHasher interface { Reset() } +// StateStorer allows to store and retrieve the state of a hash function. +type StateStorer interface { + FieldHasher + // State retrieves the current state of the hash function. Calling this + // method should not destroy the current state and allow continue the use of + // the current hasher. The number of returned Variable is implementation + // dependent. + State() []frontend.Variable + // SetState sets the state of the hash function from a previously stored + // state retrieved using [StateStorer.State] method. The implementation + // returns an error if the number of supplied Variable does not match the + // number of Variable expected. + SetState(state []frontend.Variable) error +} + var ( builderRegistry = make(map[string]func(api frontend.API) (FieldHasher, error)) lock sync.RWMutex diff --git a/std/hash/mimc/mimc_test.go b/std/hash/mimc/mimc_test.go index 6739be9359..a3e48112f8 100644 --- a/std/hash/mimc/mimc_test.go +++ b/std/hash/mimc/mimc_test.go @@ -17,6 +17,8 @@ limitations under the License. package mimc import ( + "errors" + "fmt" "math/big" "testing" @@ -93,3 +95,73 @@ func TestMimcAll(t *testing.T) { } } + +// stateStoreCircuit checks that SetState works as expected. The circuit, however +// does not check the correctness of the hashes returned by the MiMC function +// as there is another test already testing this property. +type stateStoreTestCircuit struct { + X frontend.Variable +} + +func (s *stateStoreTestCircuit) Define(api frontend.API) error { + + hsh1, err1 := NewMiMC(api) + hsh2, err2 := NewMiMC(api) + + if err1 != nil || err2 != nil { + return fmt.Errorf("could not instantiate the MIMC hasher: %w", errors.Join(err1, err2)) + } + + // This pre-shuffle the hasher state so that the test does not start from + // a zero state. + hsh1.Write(s.X) + + state := hsh1.State() + hsh2.SetState(state) + + hsh1.Write(s.X) + hsh2.Write(s.X) + + var ( + dig1 = hsh1.Sum() + dig2 = hsh2.Sum() + newState1 = hsh1.State() + newState2 = hsh2.State() + ) + + api.AssertIsEqual(dig1, dig2) + + for i := range newState1 { + api.AssertIsEqual(newState1[i], newState2[i]) + } + + return nil +} + +func TestStateStoreMiMC(t *testing.T) { + + assert := test.NewAssert(t) + + curves := map[ecc.ID]hash.Hash{ + ecc.BN254: hash.MIMC_BN254, + ecc.BLS12_381: hash.MIMC_BLS12_381, + ecc.BLS12_377: hash.MIMC_BLS12_377, + ecc.BW6_761: hash.MIMC_BW6_761, + ecc.BW6_633: hash.MIMC_BW6_633, + ecc.BLS24_315: hash.MIMC_BLS24_315, + ecc.BLS24_317: hash.MIMC_BLS24_317, + } + + for curve := range curves { + + // minimal cs res = hash(data) + var ( + circuit = &stateStoreTestCircuit{} + assignment = &stateStoreTestCircuit{X: 2} + ) + + assert.CheckCircuit(circuit, + test.WithValidAssignment(assignment), + test.WithCurves(curve)) + } +} From 1d6159381b0b1c67a6f18894c636c3852466ad1e Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 16 Dec 2024 11:50:11 +0000 Subject: [PATCH 4/7] chore: gnark-crypto update --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 6e0371490f..3d5038a5c8 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.24 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.14.1-0.20241122181107-03e007d865c0 + github.com/consensys/gnark-crypto v0.14.1-0.20241211083239-be3c2bbb1724 github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index 5a48cfc662..2044322751 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.24 h1:Lfe+bjYbpaoT7K5JTFoMi5wo9V4REGLvQQbHmatoN github.com/consensys/bavard v0.1.24/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.14.1-0.20241122181107-03e007d865c0 h1:uFZaZWG0FOoiFN3fAQzH2JXDuybdNwiJzBujy81YtU4= -github.com/consensys/gnark-crypto v0.14.1-0.20241122181107-03e007d865c0/go.mod h1:F/hJyWBcTr1sWeifAKfEN3aVb3G4U5zheEC8IbWQun4= +github.com/consensys/gnark-crypto v0.14.1-0.20241211083239-be3c2bbb1724 h1:lfTzZSy3FG2z5qFfRihDHmuolUvyEBWW8gsrjlZJQ/I= +github.com/consensys/gnark-crypto v0.14.1-0.20241211083239-be3c2bbb1724/go.mod h1:ePFa23CZLMRMHxQpY5nMaiAZ3yuEIayaB8ElEvlwLEs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= From 086befdde609639cc2fe09cf232fc9426b538fc5 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 16 Dec 2024 11:50:30 +0000 Subject: [PATCH 5/7] test: hash state serialization compatibility with gnark-crypto --- std/hash/mimc/mimc_test.go | 56 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/std/hash/mimc/mimc_test.go b/std/hash/mimc/mimc_test.go index 3667fd82b0..019862cb34 100644 --- a/std/hash/mimc/mimc_test.go +++ b/std/hash/mimc/mimc_test.go @@ -4,6 +4,7 @@ package mimc import ( + "crypto/rand" "errors" "fmt" "math/big" @@ -152,3 +153,58 @@ func TestStateStoreMiMC(t *testing.T) { test.WithCurves(curve)) } } + +type recoveredStateTestCircuit struct { + State []frontend.Variable + Input frontend.Variable + Expected frontend.Variable `gnark:",public"` +} + +func (c *recoveredStateTestCircuit) Define(api frontend.API) error { + h, err := NewMiMC(api) + if err != nil { + return fmt.Errorf("initialize hash: %w", err) + } + if err = h.SetState(c.State); err != nil { + return fmt.Errorf("set state: %w", err) + } + h.Write(c.Input) + res := h.Sum() + api.AssertIsEqual(res, c.Expected) + return nil +} + +func TestHasherFromState(t *testing.T) { + assert := test.NewAssert(t) + + hashes := map[ecc.ID]hash.Hash{ + ecc.BN254: hash.MIMC_BN254, + ecc.BLS12_381: hash.MIMC_BLS12_381, + ecc.BLS12_377: hash.MIMC_BLS12_377, + ecc.BW6_761: hash.MIMC_BW6_761, + ecc.BW6_633: hash.MIMC_BW6_633, + ecc.BLS24_315: hash.MIMC_BLS24_315, + ecc.BLS24_317: hash.MIMC_BLS24_317, + } + + for cc, hh := range hashes { + hasher := hh.New() + ss, ok := hasher.(hash.StateStorer) + assert.True(ok) + _, err := ss.Write([]byte("hello world")) + assert.NoError(err) + state := ss.State() + nbBytes := cc.ScalarField().BitLen() / 8 + buf := make([]byte, nbBytes) + _, err = rand.Read(buf) + assert.NoError(err) + ss.Write(buf) + expected := ss.Sum(nil) + bstate := new(big.Int).SetBytes(state) + binput := new(big.Int).SetBytes(buf) + assert.CheckCircuit( + &recoveredStateTestCircuit{State: make([]frontend.Variable, 1)}, + test.WithValidAssignment(&recoveredStateTestCircuit{State: []frontend.Variable{bstate}, Input: binput, Expected: expected}), + test.WithCurves(cc)) + } +} From 190c9569889c6f9a79de6d69bd32dc61e23bc57a Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 16 Dec 2024 11:57:37 +0000 Subject: [PATCH 6/7] chore: update generated tinyfield due to dep update --- internal/tinyfield/doc.go | 12 +- internal/tinyfield/element.go | 192 +++++++++------------------ internal/tinyfield/element_purego.go | 70 ++++++++++ internal/tinyfield/element_test.go | 52 ++------ internal/tinyfield/vector.go | 41 +----- internal/tinyfield/vector_purego.go | 43 ++++++ internal/tinyfield/vector_test.go | 10 +- 7 files changed, 203 insertions(+), 217 deletions(-) create mode 100644 internal/tinyfield/element_purego.go create mode 100644 internal/tinyfield/vector_purego.go diff --git a/internal/tinyfield/doc.go b/internal/tinyfield/doc.go index a8b6fce697..32f5ad92c0 100644 --- a/internal/tinyfield/doc.go +++ b/internal/tinyfield/doc.go @@ -1,17 +1,19 @@ -// Copyright 2020-2024 ConsenSys Software Inc. +// Copyright 2020-2024 Consensys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT // Package tinyfield contains field arithmetic operations for modulus = 0x2f. // -// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication) +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x). +// +// Additionally tinyfield.Vector offers an API to manipulate []Element. // // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: // -// type Element [1]uint64 +// type Element [1]uint32 // // # Usage // @@ -38,5 +40,7 @@ // // # Warning // -// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +// There is no security guarantees such as constant time implementation or side-channel attack resistance. +// This code is provided as-is. Partially audited, see https://github.com/Consensys/gnark/tree/master/audits +// for more details. package tinyfield diff --git a/internal/tinyfield/element.go b/internal/tinyfield/element.go index 5d7e45ae33..96b7ad573f 100644 --- a/internal/tinyfield/element.go +++ b/internal/tinyfield/element.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 ConsenSys Software Inc. +// Copyright 2020-2024 Consensys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/field/pool" ) -// Element represents a field element stored on 1 words (uint64) +// Element represents a field element stored on 1 words (uint32) // // Element are assumed to be in Montgomery form in all methods. // @@ -33,18 +33,18 @@ import ( // # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. -type Element [1]uint64 +type Element [1]uint32 const ( - Limbs = 1 // number of 64 bits words needed to represent a Element + Limbs = 1 // number of 32 bits words needed to represent a Element Bits = 6 // number of bits needed to represent a Element - Bytes = 8 // number of bytes needed to represent a Element + Bytes = 4 // number of bytes needed to represent a Element ) // Field modulus q const ( - q0 uint64 = 47 - q uint64 = q0 + q0 = 47 + q = q0 ) var qElement = Element{ @@ -63,7 +63,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg uint64 = 12559485326780971313 +const qInvNeg = 2558703921 func init() { _modulus.SetString("2f", 16) @@ -76,16 +76,16 @@ func init() { // var v Element // v.SetUint64(...) func NewElement(v uint64) Element { - z := Element{v} - z.Mul(&z, &rSquare) + z := Element{uint32(v % uint64(q0))} + z.toMont() return z } // SetUint64 sets z to v and returns z func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form - *z = Element{v} - return z.Mul(z, &rSquare) // z.toMont() + *z = Element{uint32(v % uint64(q0))} + return z.toMont() } // SetInt64 sets z to v and returns z @@ -178,7 +178,7 @@ func (z *Element) SetZero() *Element { // SetOne z = 1 (in Montgomery form) func (z *Element) SetOne() *Element { - z[0] = 25 + z[0] = 42 return z } @@ -196,7 +196,7 @@ func (z *Element) Equal(x *Element) bool { } // NotEqual returns 0 if and only if z == x; constant-time -func (z *Element) NotEqual(x *Element) uint64 { +func (z *Element) NotEqual(x *Element) uint32 { return (z[0] ^ x[0]) } @@ -207,7 +207,7 @@ func (z *Element) IsZero() bool { // IsOne returns z == 1 func (z *Element) IsOne() bool { - return z[0] == 25 + return z[0] == 42 } // IsUint64 reports whether z can be represented as an uint64. @@ -217,7 +217,7 @@ func (z *Element) IsUint64() bool { // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - return z.Bits()[0] + return uint64(z.Bits()[0]) } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -252,8 +252,8 @@ func (z *Element) LexicographicallyLargest() bool { _z := z.Bits() - var b uint64 - _, b = bits.Sub64(_z[0], 24, 0) + var b uint32 + _, b = bits.Sub32(_z[0], 24, 0) return b == 0 } @@ -292,7 +292,7 @@ func (z *Element) SetRandom() (*Element, error) { // Clear unused bits in in the most significant byte to increase probability // that the candidate is < q. bytes[k-1] &= uint8(int(1<> 1 @@ -338,35 +338,31 @@ func (z *Element) fromMont() *Element { // Add z = x + y (mod q) func (z *Element) Add(x, y *Element) *Element { - z[0], _ = bits.Add64(x[0], y[0], 0) - if z[0] >= q { - z[0] -= q + t := x[0] + y[0] + if t >= q { + t -= q } + z[0] = t return z } // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { - if x[0]&(1<<63) == (1 << 63) { - // if highest bit is set, then we have a carry to x + x, we shift and subtract q - z[0] = (x[0] << 1) - q - } else { - // highest bit is not set, but x + x can still be >= q - z[0] = (x[0] << 1) - if z[0] >= q { - z[0] -= q - } + t := x[0] << 1 + if t >= q { + t -= q } + z[0] = t return z } // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) + t, b := bits.Sub32(x[0], y[0], 0) if b != 0 { - z[0] += q + t += q } + z[0] = t return z } @@ -383,69 +379,13 @@ func (z *Element) Neg(x *Element) *Element { // Select is a constant-time conditional move. // If c=0, z = x0. Else z = x1 func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { - cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + cC := uint32((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise z[0] = x0[0] ^ cC&(x0[0]^x1[0]) return z } -// _mulGeneric is unoptimized textbook CIOS -// it is a fallback solution on x86 when ADX instruction set is not available -// and is used for testing purposes. -func _mulGeneric(z, x, y *Element) { - - // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" - // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 - - var t [2]uint64 - var D uint64 - var m, C uint64 - // ----------------------------------- - // First loop - - C, t[0] = bits.Mul64(y[0], x[0]) - - t[1], D = bits.Add64(t[1], C, 0) - - // m = t[0]n'[0] mod W - m = t[0] * qInvNeg - - // ----------------------------------- - // Second loop - C = madd0(m, q0, t[0]) - - t[0], C = bits.Add64(t[1], C, 0) - t[1], _ = bits.Add64(0, D, C) - - if t[1] != 0 { - // we need to reduce, we have a result on 2 words - z[0], _ = bits.Sub64(t[0], q0, 0) - return - } - - // copy t into z - z[0] = t[0] - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - z[0] -= q - } -} - func _fromMontGeneric(z *Element) { - // the following lines implement z = z * 1 - // with a modified CIOS montgomery multiplication - // see Mul for algorithm documentation - { - // m = z[0]n'[0] mod W - m := z[0] * qInvNeg - C := madd0(m, q0, z[0]) - z[0] = C - } - - // if z ⩾ q → z -= q - if !z.smallerThanModulus() { - z[0] -= q - } + z[0] = montReduce(uint64(z[0])) } func _reduceGeneric(z *Element) { @@ -498,7 +438,7 @@ func _butterflyGeneric(a, b *Element) { // BitLen returns the minimum number of bits needed to represent z // returns 0 if z == 0 func (z *Element) BitLen() int { - return bits.Len64(z[0]) + return bits.Len32(z[0]) } // Hash msg to count prime field elements. @@ -565,13 +505,15 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // see section 2.3.2 of Tolga Acar's thesis // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf var rSquare = Element{ - 14, + 25, } // toMont converts z to Montgomery form // sets and returns z = z * r² func (z *Element) toMont() *Element { - return z.Mul(z, &rSquare) + const rBits = 32 + z[0] = uint32((uint64(z[0]) << rBits) % q) + return z } // String returns the decimal representation of z as generated by @@ -583,7 +525,7 @@ func (z *Element) String() string { // toBigInt returns z as a big.Int in Montgomery form func (z *Element) toBigInt(res *big.Int) *big.Int { var b [Bytes]byte - binary.BigEndian.PutUint64(b[0:8], z[0]) + binary.BigEndian.PutUint32(b[0:4], z[0]) return res.SetBytes(b[:]) } @@ -603,7 +545,7 @@ func (z *Element) Text(base int) string { const maxUint16 = 65535 zz := z.Bits() - return strconv.FormatUint(zz[0], base) + return strconv.FormatUint(uint64(zz[0]), base) } // BigInt sets and return z as a *big.Int @@ -621,10 +563,10 @@ func (z Element) ToBigIntRegular(res *big.Int) *big.Int { return z.toBigInt(res) } -// Bits provides access to z by returning its value as a little-endian [1]uint64 array. +// Bits provides access to z by returning its value as a little-endian [1]uint32 array. // Bits is intended to support implementation of missing low-level Element // functionality outside this package; it should be avoided otherwise. -func (z *Element) Bits() [1]uint64 { +func (z *Element) Bits() [1]uint32 { _z := *z fromMont(&_z) return _z @@ -673,8 +615,8 @@ func (z *Element) SetBytes(e []byte) *Element { return z } -// SetBytesCanonical interprets e as the bytes of a big-endian 8-byte integer. -// If e is not a 8-byte slice or encodes a value higher than q, +// SetBytesCanonical interprets e as the bytes of a big-endian 4-byte integer. +// If e is not a 4-byte slice or encodes a value higher than q, // SetBytesCanonical returns an error. func (z *Element) SetBytesCanonical(e []byte) error { if len(e) != Bytes { @@ -721,19 +663,9 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - - if bits.UintSize == 64 { - for i := 0; i < len(vBits); i++ { - z[i] = uint64(vBits[i]) - } - } else { - for i := 0; i < len(vBits); i++ { - if i%2 == 0 { - z[i/2] = uint64(vBits[i]) - } else { - z[i/2] |= uint64(vBits[i]) << 32 - } - } + // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits + for i := 0; i < len(vBits); i++ { + z[i] = uint32(vBits[i]) } return z.toMont() @@ -832,11 +764,11 @@ var BigEndian bigEndian type bigEndian struct{} -// Element interpret b is a big-endian 8-byte slice. +// Element interpret b is a big-endian 4-byte slice. // If b encodes a value higher than q, Element returns error. func (bigEndian) Element(b *[Bytes]byte) (Element, error) { var z Element - z[0] = binary.BigEndian.Uint64((*b)[0:8]) + z[0] = binary.BigEndian.Uint32((*b)[0:4]) if !z.smallerThanModulus() { return Element{}, errors.New("invalid tinyfield.Element encoding") @@ -848,7 +780,7 @@ func (bigEndian) Element(b *[Bytes]byte) (Element, error) { func (bigEndian) PutElement(b *[Bytes]byte, e Element) { e.fromMont() - binary.BigEndian.PutUint64((*b)[0:8], e[0]) + binary.BigEndian.PutUint32((*b)[0:4], e[0]) } func (bigEndian) String() string { return "BigEndian" } @@ -860,7 +792,7 @@ type littleEndian struct{} func (littleEndian) Element(b *[Bytes]byte) (Element, error) { var z Element - z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[0] = binary.LittleEndian.Uint32((*b)[0:4]) if !z.smallerThanModulus() { return Element{}, errors.New("invalid tinyfield.Element encoding") @@ -872,7 +804,7 @@ func (littleEndian) Element(b *[Bytes]byte) (Element, error) { func (littleEndian) PutElement(b *[Bytes]byte, e Element) { e.fromMont() - binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint32((*b)[0:4], e[0]) } func (littleEndian) String() string { return "LittleEndian" } @@ -926,19 +858,19 @@ func (z *Element) Sqrt(x *Element) *Element { // if x == 0, sets and returns z = x func (z *Element) Inverse(x *Element) *Element { // Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" - const q uint64 = q0 + const q uint32 = q0 if x.IsZero() { z.SetZero() return z } - var r, s, u, v uint64 + var r, s, u, v uint32 u = q - s = 14 // s = r² + s = 25 // s = r² r = 0 v = x[0] - var carry, borrow uint64 + var carry, borrow uint32 for (u != 1) && (v != 1) { for v&1 == 0 { @@ -946,10 +878,10 @@ func (z *Element) Inverse(x *Element) *Element { if s&1 == 0 { s >>= 1 } else { - s, carry = bits.Add64(s, q, 0) + s, carry = bits.Add32(s, q, 0) s >>= 1 if carry != 0 { - s |= (1 << 63) + s |= (1 << 31) } } } @@ -958,22 +890,22 @@ func (z *Element) Inverse(x *Element) *Element { if r&1 == 0 { r >>= 1 } else { - r, carry = bits.Add64(r, q, 0) + r, carry = bits.Add32(r, q, 0) r >>= 1 if carry != 0 { - r |= (1 << 63) + r |= (1 << 31) } } } if v >= u { v -= u - s, borrow = bits.Sub64(s, r, 0) + s, borrow = bits.Sub32(s, r, 0) if borrow == 1 { s += q } } else { u -= v - r, borrow = bits.Sub64(r, s, 0) + r, borrow = bits.Sub32(r, s, 0) if borrow == 1 { r += q } diff --git a/internal/tinyfield/element_purego.go b/internal/tinyfield/element_purego.go new file mode 100644 index 0000000000..301cd8589e --- /dev/null +++ b/internal/tinyfield/element_purego.go @@ -0,0 +1,70 @@ +// Copyright 2020-2024 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package tinyfield + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + var y Element + y.SetUint64(3) + x.Mul(x, &y) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + var y Element + y.SetUint64(5) + x.Mul(x, &y) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y Element + y.SetUint64(13) + x.Mul(x, &y) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} +func montReduce(v uint64) uint32 { + m := uint32(v) * qInvNeg + t := uint32((v + uint64(m)*q) >> 32) + if t >= q { + t -= q + } + return t +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + v := uint64(x[0]) * uint64(y[0]) + z[0] = montReduce(v) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + v := uint64(x[0]) * uint64(x[0]) + z[0] = montReduce(v) + return z +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} diff --git a/internal/tinyfield/element_test.go b/internal/tinyfield/element_test.go index 64d9667a54..7894b2a0cc 100644 --- a/internal/tinyfield/element_test.go +++ b/internal/tinyfield/element_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 ConsenSys Software Inc. +// Copyright 2020-2024 Consensys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -192,7 +192,7 @@ func BenchmarkElementSqrt(b *testing.B) { func BenchmarkElementMul(b *testing.B) { x := Element{ - 14, + 25, } benchResElement.SetOne() b.ResetTimer() @@ -203,7 +203,7 @@ func BenchmarkElementMul(b *testing.B) { func BenchmarkElementCmp(b *testing.B) { x := Element{ - 14, + 25, } benchResElement = x benchResElement[0] = 0 @@ -921,14 +921,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a.element, &r) d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) - // checking generic impl against asm path - var cGeneric Element - _mulGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. - return false - } - if c.BigInt(&e).Cmp(&d) != 0 { return false } @@ -951,17 +943,6 @@ func TestElementMul(t *testing.T) { genB, )) - properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Mul(&a.element, &b.element) - _mulGeneric(&d, &a.element, &b.element) - return c.Equal(&d) - }, - genA, - genB, - )) - specialValueTest := func() { // test special values against special values testValues := make([]Element, len(staticTestValues)) @@ -980,13 +961,6 @@ func TestElementMul(t *testing.T) { c.Mul(&a, &b) d.Mul(&aBig, &bBig).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _mulGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Mul failed special test values: asm and generic impl don't match") - } - if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } @@ -2126,17 +2100,17 @@ func gen() gopter.Gen { var g testPairElement g.element = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g.element[0] %= (qElement[0] + 1) } for !g.element.smallerThanModulus() { g.element = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g.element[0] %= (qElement[0] + 1) } } @@ -2151,18 +2125,18 @@ func genRandomFq(genParams *gopter.GenParameters) Element { var g Element g = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g[0] %= (qElement[0] + 1) } for !g.smallerThanModulus() { g = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { g[0] %= (qElement[0] + 1) } } @@ -2174,8 +2148,8 @@ func genFull() gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { a := genRandomFq(genParams) - var carry uint64 - a[0], _ = bits.Add64(a[0], qElement[0], carry) + var carry uint32 + a[0], _ = bits.Add32(a[0], qElement[0], carry) genResult := gopter.NewGenResult(a, gopter.NoShrinker) return genResult diff --git a/internal/tinyfield/vector.go b/internal/tinyfield/vector.go index 6b045db8cd..db5a956511 100644 --- a/internal/tinyfield/vector.go +++ b/internal/tinyfield/vector.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 ConsenSys Software Inc. +// Copyright 2020-2024 Consensys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -108,7 +108,7 @@ func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { bstart := i * Bytes bend := bstart + Bytes b := bSlice[bstart:bend] - z[0] = binary.BigEndian.Uint64(b[0:8]) + z[0] = binary.BigEndian.Uint32(b[0:4]) if !z.smallerThanModulus() { atomic.AddUint64(&cptErrors, 1) @@ -185,43 +185,6 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} - func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/internal/tinyfield/vector_purego.go b/internal/tinyfield/vector_purego.go new file mode 100644 index 0000000000..22a2964d1f --- /dev/null +++ b/internal/tinyfield/vector_purego.go @@ -0,0 +1,43 @@ +// Copyright 2020-2024 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package tinyfield + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/internal/tinyfield/vector_test.go b/internal/tinyfield/vector_test.go index d17149d308..c548b40f59 100644 --- a/internal/tinyfield/vector_test.go +++ b/internal/tinyfield/vector_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 ConsenSys Software Inc. +// Copyright 2020-2024 Consensys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -328,17 +328,17 @@ func genVector(size int) gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { g := make(Vector, size) mixer := Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { mixer[0] %= (qElement[0] + 1) } for !mixer.smallerThanModulus() { mixer = Element{ - genParams.NextUint64(), + uint32(genParams.NextUint64()), } - if qElement[0] != ^uint64(0) { + if qElement[0] != ^uint32(0) { mixer[0] %= (qElement[0] + 1) } } From c324a2a38abb0696a70fc0c5dbf38c4b9d058561 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 16 Dec 2024 13:00:17 +0000 Subject: [PATCH 7/7] Revert "chore: update generated tinyfield due to dep update" This reverts commit 190c9569889c6f9a79de6d69bd32dc61e23bc57a. --- internal/tinyfield/doc.go | 12 +- internal/tinyfield/element.go | 192 ++++++++++++++++++--------- internal/tinyfield/element_purego.go | 70 ---------- internal/tinyfield/element_test.go | 52 ++++++-- internal/tinyfield/vector.go | 41 +++++- internal/tinyfield/vector_purego.go | 43 ------ internal/tinyfield/vector_test.go | 10 +- 7 files changed, 217 insertions(+), 203 deletions(-) delete mode 100644 internal/tinyfield/element_purego.go delete mode 100644 internal/tinyfield/vector_purego.go diff --git a/internal/tinyfield/doc.go b/internal/tinyfield/doc.go index 32f5ad92c0..a8b6fce697 100644 --- a/internal/tinyfield/doc.go +++ b/internal/tinyfield/doc.go @@ -1,19 +1,17 @@ -// Copyright 2020-2024 Consensys Software Inc. +// Copyright 2020-2024 ConsenSys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT // Package tinyfield contains field arithmetic operations for modulus = 0x2f. // -// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x). -// -// Additionally tinyfield.Vector offers an API to manipulate []Element. +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication) // // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: // -// type Element [1]uint32 +// type Element [1]uint64 // // # Usage // @@ -40,7 +38,5 @@ // // # Warning // -// There is no security guarantees such as constant time implementation or side-channel attack resistance. -// This code is provided as-is. Partially audited, see https://github.com/Consensys/gnark/tree/master/audits -// for more details. +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package tinyfield diff --git a/internal/tinyfield/element.go b/internal/tinyfield/element.go index 96b7ad573f..5d7e45ae33 100644 --- a/internal/tinyfield/element.go +++ b/internal/tinyfield/element.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 Consensys Software Inc. +// Copyright 2020-2024 ConsenSys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/field/pool" ) -// Element represents a field element stored on 1 words (uint32) +// Element represents a field element stored on 1 words (uint64) // // Element are assumed to be in Montgomery form in all methods. // @@ -33,18 +33,18 @@ import ( // # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. -type Element [1]uint32 +type Element [1]uint64 const ( - Limbs = 1 // number of 32 bits words needed to represent a Element + Limbs = 1 // number of 64 bits words needed to represent a Element Bits = 6 // number of bits needed to represent a Element - Bytes = 4 // number of bytes needed to represent a Element + Bytes = 8 // number of bytes needed to represent a Element ) // Field modulus q const ( - q0 = 47 - q = q0 + q0 uint64 = 47 + q uint64 = q0 ) var qElement = Element{ @@ -63,7 +63,7 @@ func Modulus() *big.Int { // q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r // used for Montgomery reduction -const qInvNeg = 2558703921 +const qInvNeg uint64 = 12559485326780971313 func init() { _modulus.SetString("2f", 16) @@ -76,16 +76,16 @@ func init() { // var v Element // v.SetUint64(...) func NewElement(v uint64) Element { - z := Element{uint32(v % uint64(q0))} - z.toMont() + z := Element{v} + z.Mul(&z, &rSquare) return z } // SetUint64 sets z to v and returns z func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form - *z = Element{uint32(v % uint64(q0))} - return z.toMont() + *z = Element{v} + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -178,7 +178,7 @@ func (z *Element) SetZero() *Element { // SetOne z = 1 (in Montgomery form) func (z *Element) SetOne() *Element { - z[0] = 42 + z[0] = 25 return z } @@ -196,7 +196,7 @@ func (z *Element) Equal(x *Element) bool { } // NotEqual returns 0 if and only if z == x; constant-time -func (z *Element) NotEqual(x *Element) uint32 { +func (z *Element) NotEqual(x *Element) uint64 { return (z[0] ^ x[0]) } @@ -207,7 +207,7 @@ func (z *Element) IsZero() bool { // IsOne returns z == 1 func (z *Element) IsOne() bool { - return z[0] == 42 + return z[0] == 25 } // IsUint64 reports whether z can be represented as an uint64. @@ -217,7 +217,7 @@ func (z *Element) IsUint64() bool { // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - return uint64(z.Bits()[0]) + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -252,8 +252,8 @@ func (z *Element) LexicographicallyLargest() bool { _z := z.Bits() - var b uint32 - _, b = bits.Sub32(_z[0], 24, 0) + var b uint64 + _, b = bits.Sub64(_z[0], 24, 0) return b == 0 } @@ -292,7 +292,7 @@ func (z *Element) SetRandom() (*Element, error) { // Clear unused bits in in the most significant byte to increase probability // that the candidate is < q. bytes[k-1] &= uint8(int(1<> 1 @@ -338,31 +338,35 @@ func (z *Element) fromMont() *Element { // Add z = x + y (mod q) func (z *Element) Add(x, y *Element) *Element { - t := x[0] + y[0] - if t >= q { - t -= q + z[0], _ = bits.Add64(x[0], y[0], 0) + if z[0] >= q { + z[0] -= q } - z[0] = t return z } // Double z = x + x (mod q), aka Lsh 1 func (z *Element) Double(x *Element) *Element { - t := x[0] << 1 - if t >= q { - t -= q + if x[0]&(1<<63) == (1 << 63) { + // if highest bit is set, then we have a carry to x + x, we shift and subtract q + z[0] = (x[0] << 1) - q + } else { + // highest bit is not set, but x + x can still be >= q + z[0] = (x[0] << 1) + if z[0] >= q { + z[0] -= q + } } - z[0] = t return z } // Sub z = x - y (mod q) func (z *Element) Sub(x, y *Element) *Element { - t, b := bits.Sub32(x[0], y[0], 0) + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) if b != 0 { - t += q + z[0] += q } - z[0] = t return z } @@ -379,13 +383,69 @@ func (z *Element) Neg(x *Element) *Element { // Select is a constant-time conditional move. // If c=0, z = x0. Else z = x1 func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { - cC := uint32((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise z[0] = x0[0] ^ cC&(x0[0]^x1[0]) return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. +func _mulGeneric(z, x, y *Element) { + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + var t [2]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + + t[1], D = bits.Add64(t[1], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + + t[0], C = bits.Add64(t[1], C, 0) + t[1], _ = bits.Add64(0, D, C) + + if t[1] != 0 { + // we need to reduce, we have a result on 2 words + z[0], _ = bits.Sub64(t[0], q0, 0) + return + } + + // copy t into z + z[0] = t[0] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } +} + func _fromMontGeneric(z *Element) { - z[0] = montReduce(uint64(z[0])) + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + z[0] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } } func _reduceGeneric(z *Element) { @@ -438,7 +498,7 @@ func _butterflyGeneric(a, b *Element) { // BitLen returns the minimum number of bits needed to represent z // returns 0 if z == 0 func (z *Element) BitLen() int { - return bits.Len32(z[0]) + return bits.Len64(z[0]) } // Hash msg to count prime field elements. @@ -505,15 +565,13 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // see section 2.3.2 of Tolga Acar's thesis // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf var rSquare = Element{ - 25, + 14, } // toMont converts z to Montgomery form // sets and returns z = z * r² func (z *Element) toMont() *Element { - const rBits = 32 - z[0] = uint32((uint64(z[0]) << rBits) % q) - return z + return z.Mul(z, &rSquare) } // String returns the decimal representation of z as generated by @@ -525,7 +583,7 @@ func (z *Element) String() string { // toBigInt returns z as a big.Int in Montgomery form func (z *Element) toBigInt(res *big.Int) *big.Int { var b [Bytes]byte - binary.BigEndian.PutUint32(b[0:4], z[0]) + binary.BigEndian.PutUint64(b[0:8], z[0]) return res.SetBytes(b[:]) } @@ -545,7 +603,7 @@ func (z *Element) Text(base int) string { const maxUint16 = 65535 zz := z.Bits() - return strconv.FormatUint(uint64(zz[0]), base) + return strconv.FormatUint(zz[0], base) } // BigInt sets and return z as a *big.Int @@ -563,10 +621,10 @@ func (z Element) ToBigIntRegular(res *big.Int) *big.Int { return z.toBigInt(res) } -// Bits provides access to z by returning its value as a little-endian [1]uint32 array. +// Bits provides access to z by returning its value as a little-endian [1]uint64 array. // Bits is intended to support implementation of missing low-level Element // functionality outside this package; it should be avoided otherwise. -func (z *Element) Bits() [1]uint32 { +func (z *Element) Bits() [1]uint64 { _z := *z fromMont(&_z) return _z @@ -615,8 +673,8 @@ func (z *Element) SetBytes(e []byte) *Element { return z } -// SetBytesCanonical interprets e as the bytes of a big-endian 4-byte integer. -// If e is not a 4-byte slice or encodes a value higher than q, +// SetBytesCanonical interprets e as the bytes of a big-endian 8-byte integer. +// If e is not a 8-byte slice or encodes a value higher than q, // SetBytesCanonical returns an error. func (z *Element) SetBytesCanonical(e []byte) error { if len(e) != Bytes { @@ -663,9 +721,19 @@ func (z *Element) SetBigInt(v *big.Int) *Element { // setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() - // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits - for i := 0; i < len(vBits); i++ { - z[i] = uint32(vBits[i]) + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = uint64(vBits[i]) + } else { + z[i/2] |= uint64(vBits[i]) << 32 + } + } } return z.toMont() @@ -764,11 +832,11 @@ var BigEndian bigEndian type bigEndian struct{} -// Element interpret b is a big-endian 4-byte slice. +// Element interpret b is a big-endian 8-byte slice. // If b encodes a value higher than q, Element returns error. func (bigEndian) Element(b *[Bytes]byte) (Element, error) { var z Element - z[0] = binary.BigEndian.Uint32((*b)[0:4]) + z[0] = binary.BigEndian.Uint64((*b)[0:8]) if !z.smallerThanModulus() { return Element{}, errors.New("invalid tinyfield.Element encoding") @@ -780,7 +848,7 @@ func (bigEndian) Element(b *[Bytes]byte) (Element, error) { func (bigEndian) PutElement(b *[Bytes]byte, e Element) { e.fromMont() - binary.BigEndian.PutUint32((*b)[0:4], e[0]) + binary.BigEndian.PutUint64((*b)[0:8], e[0]) } func (bigEndian) String() string { return "BigEndian" } @@ -792,7 +860,7 @@ type littleEndian struct{} func (littleEndian) Element(b *[Bytes]byte) (Element, error) { var z Element - z[0] = binary.LittleEndian.Uint32((*b)[0:4]) + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) if !z.smallerThanModulus() { return Element{}, errors.New("invalid tinyfield.Element encoding") @@ -804,7 +872,7 @@ func (littleEndian) Element(b *[Bytes]byte) (Element, error) { func (littleEndian) PutElement(b *[Bytes]byte, e Element) { e.fromMont() - binary.LittleEndian.PutUint32((*b)[0:4], e[0]) + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) } func (littleEndian) String() string { return "LittleEndian" } @@ -858,19 +926,19 @@ func (z *Element) Sqrt(x *Element) *Element { // if x == 0, sets and returns z = x func (z *Element) Inverse(x *Element) *Element { // Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" - const q uint32 = q0 + const q uint64 = q0 if x.IsZero() { z.SetZero() return z } - var r, s, u, v uint32 + var r, s, u, v uint64 u = q - s = 25 // s = r² + s = 14 // s = r² r = 0 v = x[0] - var carry, borrow uint32 + var carry, borrow uint64 for (u != 1) && (v != 1) { for v&1 == 0 { @@ -878,10 +946,10 @@ func (z *Element) Inverse(x *Element) *Element { if s&1 == 0 { s >>= 1 } else { - s, carry = bits.Add32(s, q, 0) + s, carry = bits.Add64(s, q, 0) s >>= 1 if carry != 0 { - s |= (1 << 31) + s |= (1 << 63) } } } @@ -890,22 +958,22 @@ func (z *Element) Inverse(x *Element) *Element { if r&1 == 0 { r >>= 1 } else { - r, carry = bits.Add32(r, q, 0) + r, carry = bits.Add64(r, q, 0) r >>= 1 if carry != 0 { - r |= (1 << 31) + r |= (1 << 63) } } } if v >= u { v -= u - s, borrow = bits.Sub32(s, r, 0) + s, borrow = bits.Sub64(s, r, 0) if borrow == 1 { s += q } } else { u -= v - r, borrow = bits.Sub32(r, s, 0) + r, borrow = bits.Sub64(r, s, 0) if borrow == 1 { r += q } diff --git a/internal/tinyfield/element_purego.go b/internal/tinyfield/element_purego.go deleted file mode 100644 index 301cd8589e..0000000000 --- a/internal/tinyfield/element_purego.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2020-2024 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package tinyfield - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - var y Element - y.SetUint64(3) - x.Mul(x, &y) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - var y Element - y.SetUint64(5) - x.Mul(x, &y) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y Element - y.SetUint64(13) - x.Mul(x, &y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} -func montReduce(v uint64) uint32 { - m := uint32(v) * qInvNeg - t := uint32((v + uint64(m)*q) >> 32) - if t >= q { - t -= q - } - return t -} - -// Mul z = x * y (mod q) -// -// x and y must be less than q -func (z *Element) Mul(x, y *Element) *Element { - v := uint64(x[0]) * uint64(y[0]) - z[0] = montReduce(v) - return z -} - -// Square z = x * x (mod q) -// -// x must be less than q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - v := uint64(x[0]) * uint64(x[0]) - z[0] = montReduce(v) - return z -} - -// Butterfly sets -// -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} diff --git a/internal/tinyfield/element_test.go b/internal/tinyfield/element_test.go index 7894b2a0cc..64d9667a54 100644 --- a/internal/tinyfield/element_test.go +++ b/internal/tinyfield/element_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 Consensys Software Inc. +// Copyright 2020-2024 ConsenSys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -192,7 +192,7 @@ func BenchmarkElementSqrt(b *testing.B) { func BenchmarkElementMul(b *testing.B) { x := Element{ - 25, + 14, } benchResElement.SetOne() b.ResetTimer() @@ -203,7 +203,7 @@ func BenchmarkElementMul(b *testing.B) { func BenchmarkElementCmp(b *testing.B) { x := Element{ - 25, + 14, } benchResElement = x benchResElement[0] = 0 @@ -921,6 +921,14 @@ func TestElementMul(t *testing.T) { c.Mul(&a.element, &r) d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + if c.BigInt(&e).Cmp(&d) != 0 { return false } @@ -943,6 +951,17 @@ func TestElementMul(t *testing.T) { genB, )) + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + specialValueTest := func() { // test special values against special values testValues := make([]Element, len(staticTestValues)) @@ -961,6 +980,13 @@ func TestElementMul(t *testing.T) { c.Mul(&a, &b) d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special test values: asm and generic impl don't match") + } + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } @@ -2100,17 +2126,17 @@ func gen() gopter.Gen { var g testPairElement g.element = Element{ - uint32(genParams.NextUint64()), + genParams.NextUint64(), } - if qElement[0] != ^uint32(0) { + if qElement[0] != ^uint64(0) { g.element[0] %= (qElement[0] + 1) } for !g.element.smallerThanModulus() { g.element = Element{ - uint32(genParams.NextUint64()), + genParams.NextUint64(), } - if qElement[0] != ^uint32(0) { + if qElement[0] != ^uint64(0) { g.element[0] %= (qElement[0] + 1) } } @@ -2125,18 +2151,18 @@ func genRandomFq(genParams *gopter.GenParameters) Element { var g Element g = Element{ - uint32(genParams.NextUint64()), + genParams.NextUint64(), } - if qElement[0] != ^uint32(0) { + if qElement[0] != ^uint64(0) { g[0] %= (qElement[0] + 1) } for !g.smallerThanModulus() { g = Element{ - uint32(genParams.NextUint64()), + genParams.NextUint64(), } - if qElement[0] != ^uint32(0) { + if qElement[0] != ^uint64(0) { g[0] %= (qElement[0] + 1) } } @@ -2148,8 +2174,8 @@ func genFull() gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { a := genRandomFq(genParams) - var carry uint32 - a[0], _ = bits.Add32(a[0], qElement[0], carry) + var carry uint64 + a[0], _ = bits.Add64(a[0], qElement[0], carry) genResult := gopter.NewGenResult(a, gopter.NoShrinker) return genResult diff --git a/internal/tinyfield/vector.go b/internal/tinyfield/vector.go index db5a956511..6b045db8cd 100644 --- a/internal/tinyfield/vector.go +++ b/internal/tinyfield/vector.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 Consensys Software Inc. +// Copyright 2020-2024 ConsenSys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -108,7 +108,7 @@ func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { bstart := i * Bytes bend := bstart + Bytes b := bSlice[bstart:bend] - z[0] = binary.BigEndian.Uint32(b[0:4]) + z[0] = binary.BigEndian.Uint64(b[0:8]) if !z.smallerThanModulus() { atomic.AddUint64(&cptErrors, 1) @@ -185,6 +185,43 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/internal/tinyfield/vector_purego.go b/internal/tinyfield/vector_purego.go deleted file mode 100644 index 22a2964d1f..0000000000 --- a/internal/tinyfield/vector_purego.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2020-2024 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package tinyfield - -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - -// InnerProduct computes the inner product of two vectors. -// It panics if the vectors don't have the same length. -func (vector *Vector) InnerProduct(other Vector) (res Element) { - innerProductVecGeneric(&res, *vector, other) - return -} - -// Mul multiplies two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Mul(a, b Vector) { - mulVecGeneric(*vector, a, b) -} diff --git a/internal/tinyfield/vector_test.go b/internal/tinyfield/vector_test.go index c548b40f59..d17149d308 100644 --- a/internal/tinyfield/vector_test.go +++ b/internal/tinyfield/vector_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2024 Consensys Software Inc. +// Copyright 2020-2024 ConsenSys Software Inc. // Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -328,17 +328,17 @@ func genVector(size int) gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { g := make(Vector, size) mixer := Element{ - uint32(genParams.NextUint64()), + genParams.NextUint64(), } - if qElement[0] != ^uint32(0) { + if qElement[0] != ^uint64(0) { mixer[0] %= (qElement[0] + 1) } for !mixer.smallerThanModulus() { mixer = Element{ - uint32(genParams.NextUint64()), + genParams.NextUint64(), } - if qElement[0] != ^uint32(0) { + if qElement[0] != ^uint64(0) { mixer[0] %= (qElement[0] + 1) } }