diff --git a/backend/groth16/bn254/mpcsetup/marshal_test.go b/backend/groth16/bn254/mpcsetup/marshal_test.go deleted file mode 100644 index adbfc3fe0..000000000 --- a/backend/groth16/bn254/mpcsetup/marshal_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020-2024 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package mpcsetup - -/* TODO bring this back -func TestContributionSerialization(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode.") - } - assert := require.New(t) - - // Phase 1 - var srs1 Phase1 - srs1.Initialize(1 << 9) - srs1.Contribute() - - assert.NoError(gnarkio.RoundTripCheck(&srs1, func() interface{} { return new(Phase1) })) - - var myCircuit Circuit - ccs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &myCircuit) - assert.NoError(err) - - r1cs := ccs.(*cs.R1CS) - - // Phase 2 - srs2, _ := InitPhase2(r1cs, &srs1) - srs2.Contribute() - - assert.NoError(gnarkio.RoundTripCheck(&srs2, func() interface{} { return new(Phase2) })) -} -*/ diff --git a/backend/groth16/bn254/mpcsetup/phase2.go b/backend/groth16/bn254/mpcsetup/phase2.go index 5b64068c4..957a3f70a 100644 --- a/backend/groth16/bn254/mpcsetup/phase2.go +++ b/backend/groth16/bn254/mpcsetup/phase2.go @@ -34,7 +34,7 @@ type Phase2Evaluations struct { // TODO @Tabaie rename B []curve.G2Affine // B are the right coefficient polynomials for each witness element, evaluated at τ } PublicAndCommitmentCommitted [][]int - NbConstraints uint64 + NbConstraints uint64 // TODO unnecessary. len(Z) has that information (domain size) } type Phase2 struct { @@ -348,6 +348,7 @@ func VerifyPhase2(r1cs *cs.R1CS, commons *SrsCommons, beaconChallenge []byte, c func (p *Phase2) hash() []byte { sha := sha256.New() p.WriteTo(sha) + sha.Write(p.Challenge) return sha.Sum(nil) } diff --git a/backend/groth16/bn254/mpcsetup/setup.go b/backend/groth16/bn254/mpcsetup/setup.go index 739438c29..715ac7fc3 100644 --- a/backend/groth16/bn254/mpcsetup/setup.go +++ b/backend/groth16/bn254/mpcsetup/setup.go @@ -36,7 +36,7 @@ func (p *Phase2) Seal(commons *SrsCommons, evals *Phase2Evaluations, beaconChall ) // Initialize PK - pk.Domain = *fft.NewDomain(evals.NbConstraints) + pk.Domain = *fft.NewDomain(evals.NbConstraints) // TODO @Tabaie replace with len(Z)+1 pk.G1.Alpha.Set(&commons.G1.AlphaTau[0]) pk.G1.Beta.Set(&commons.G1.BetaTau[0]) pk.G1.Delta.Set(&p.Parameters.G1.Delta) diff --git a/backend/groth16/bn254/mpcsetup/setup_test.go b/backend/groth16/bn254/mpcsetup/setup_test.go index 5760d471b..d5714bca9 100644 --- a/backend/groth16/bn254/mpcsetup/setup_test.go +++ b/backend/groth16/bn254/mpcsetup/setup_test.go @@ -12,6 +12,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" cs "github.com/consensys/gnark/constraint/bn254" "io" + "slices" "sync" "testing" @@ -50,7 +51,7 @@ func testAll(t *testing.T, nbContributionsPhase1, nbContributionsPhase2 int) { bb.Reset() _, err := v.WriteTo(&bb) assert.NoError(err) - return bb.Bytes() + return slices.Clone(bb.Bytes()) } deserialize := func(v io.ReaderFrom, b []byte) { n, err := v.ReadFrom(bytes.NewReader(b)) diff --git a/backend/groth16/bn254/mpcsetup/unit_test.go b/backend/groth16/bn254/mpcsetup/unit_test.go index d8c0f0011..35cae3d23 100644 --- a/backend/groth16/bn254/mpcsetup/unit_test.go +++ b/backend/groth16/bn254/mpcsetup/unit_test.go @@ -8,6 +8,11 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark/backend/groth16" groth16Impl "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/constraint" + cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + gnarkio "github.com/consensys/gnark/io" "github.com/stretchr/testify/require" "math/big" "slices" @@ -106,7 +111,7 @@ func TestNoContributors(t *testing.T) { } func TestOnePhase1Contribute(t *testing.T) { - testAll(t, 1, 0) + testAll(t, 2, 0) } func TestUpdateCheck(t *testing.T) { @@ -416,3 +421,97 @@ func TestLinearCombinationsG2(t *testing.T) { frs(1, 3, 9, 27, 81), ) } + +func ones(N int) []fr.Element { + res := make([]fr.Element, N) + for i := range res { + res[i].SetOne() + } + return res +} + +func frs(x ...int) []fr.Element { + res := make([]fr.Element, len(x)) + for i := range res { + res[i].SetInt64(int64(x[i])) + } + return res +} + +func TestSerialization(t *testing.T) { + + testRoundtrip := func(_cs constraint.ConstraintSystem) { + var ( + p1 Phase1 + p2 Phase2 + ) + p1.Initialize(ecc.NextPowerOfTwo(uint64(_cs.GetNbConstraints()))) + commons := p1.Seal([]byte("beacon 1")) + + p2.Initialize(_cs.(*cs.R1CS), &commons) + p2.Contribute() + require.NoError(t, gnarkio.RoundTripCheck(&p2, func() interface{} { return new(Phase2) })) + } + + /*var p Phase2 + const b64 = "AACNaN0mCOtKUAD0aEvRP0h7pXctaB+w5Mwsb+skm2yDuPzlwTs+qCFf/3INR+fP/lHY6BLnqXyBjAIgCoPxOcSIEG0tcty/TAiaCN3lHCRacU+upLP+WpngByrrxbN9KrhmQLY3mhOHaV5Jo3W9pI2lTpLK9ZjkQpYKd92YCRKkJ9LyX3wqeYR4jQFf1mxtfJSNgluSZUUn3AoUSDmvh8m87TRh/JRcRZnq40BgnhkJ5nHs9siMSmhWGFjGgW/mOqpyrFoZEoK2rP+AT6ylkNGYxMmOBUj0meoeI2FB7RDqcuSxQOL1XK+Pm1dhxND33cykwpTF4oCrqQzSonxQGn+wFNzaYREOmkjCS9i12NbpXNyN2b9YpmujAL/GSD5LAwKNaN0mCOtKUAD0aEvRP0h7pXctaB+w5Mwsb+skm2yDuJ8HrqP1uckhSJCcTOeHMHyh0VqJtnoMhkRAWRPEWcsqIP3sH81riS5ARP1Pv172lVAmfoXnCzwFPNFPnvdSGFk=" + b, err := base64.StdEncoding.DecodeString(b64) + require.NoError(t, err) + n, err := p.ReadFrom(bytes.NewReader(b)) + require.NoError(t, err) + require.Equal(t, int64(len(b)), n)*/ + + _cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &tinyCircuit{}) + require.NoError(t, err) + testRoundtrip(_cs) + + testRoundtrip(getTestCircuit(t)) +} + +type tinyCircuit struct { + X [4]frontend.Variable `gnark:",public"` +} + +func (c *tinyCircuit) Define(api frontend.API) error { + for i := range c.X { + api.AssertIsEqual(c.X[i], i) + } + return nil +} + +func (p *Phase2) Equal(o *Phase2) bool { + + if p.Parameters.G2.Delta != o.Parameters.G2.Delta { + print("g2 delta") + } + + if p.Delta != o.Delta { + print("proof delta") + } + + if p.Parameters.G1.Delta != o.Parameters.G1.Delta { + print("g1 delta") + } + + return p.Parameters.G2.Delta == o.Parameters.G2.Delta && + slices.Equal(p.Sigmas, o.Sigmas) && + // bytes.Equal(p.Challenge, o.Challenge) && This function is used in serialization round-trip testing, and we deliberately don't write the challenges + p.Delta == o.Delta && + sliceSliceEqual(p.Parameters.G1.SigmaCKK, o.Parameters.G1.SigmaCKK) && + p.Parameters.G1.Delta == o.Parameters.G1.Delta && + slices.Equal(p.Parameters.G1.Z, o.Parameters.G1.Z) && + slices.Equal(p.Parameters.G1.PKK, o.Parameters.G1.PKK) && + slices.Equal(p.Parameters.G2.Sigma, o.Parameters.G2.Sigma) +} + +func sliceSliceEqual[T comparable](a, b [][]T) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !slices.Equal(a[i], b[i]) { + return false + } + } + return true +} diff --git a/go.mod b/go.mod index fbf26d48f..f5b5106d5 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/ronanh/intcomp v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 - golang.org/x/crypto v0.26.0 + golang.org/x/crypto v0.31.0 golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 golang.org/x/sync v0.8.0 ) @@ -33,7 +33,9 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/sys v0.24.0 // indirect + golang.org/x/sys v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) + +replace github.com/consensys/gnark-crypto => /Users/arya/gnark-crypto diff --git a/go.sum b/go.sum index e4a044940..51a0a31e8 100644 --- a/go.sum +++ b/go.sum @@ -308,6 +308,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -466,6 +468,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= diff --git a/io/roundtrip.go b/io/roundtrip.go index ecac78afa..66d38aa7c 100644 --- a/io/roundtrip.go +++ b/io/roundtrip.go @@ -3,6 +3,7 @@ package io import ( "bytes" "errors" + "fmt" "io" "reflect" ) @@ -21,8 +22,8 @@ func RoundTripCheck(from any, to func() any) error { if err != nil { return err } - if !reflect.DeepEqual(from, r) { - return errors.New("reconstructed object don't match original (ReadFrom)") + if err = equal(from, r); err != nil { + return fmt.Errorf("ReadFrom: %w", err) } if written != read { return errors.New("bytes written / read don't match") @@ -35,8 +36,8 @@ func RoundTripCheck(from any, to func() any) error { if err != nil { return err } - if !reflect.DeepEqual(from, r) { - return errors.New("reconstructed object don't match original (UnsafeReadFrom)") + if err = equal(from, r); err != nil { + return fmt.Errorf("UnsafeReadFrom: %w", err) } if written != read { return errors.New("bytes written / read don't match") @@ -52,6 +53,8 @@ func RoundTripCheck(from any, to func() any) error { return err } + //fmt.Println(base64.StdEncoding.EncodeToString(buf.Bytes()[:written])) + if err := reconstruct(written); err != nil { return err } @@ -85,8 +88,28 @@ func DumpRoundTripCheck(from any, to func() any) error { if err := r.ReadDump(bytes.NewReader(buf.Bytes())); err != nil { return err } - if !reflect.DeepEqual(from, r) { - return errors.New("reconstructed object don't match original (ReadDump)") + if err := equal(from, r); err != nil { + return fmt.Errorf("ReadDump: %w", err) } return nil } + +func equal(a, b any) error { + // check for a custom Equal method + aV := reflect.ValueOf(a) + eq := aV.MethodByName("Equal") + if eq.IsValid() { + res := eq.Call([]reflect.Value{reflect.ValueOf(b)}) + if len(res) != 1 { + return errors.New("`Equal` method must return a single bool") + } + if res[0].Bool() { + return nil + } + return errors.New("reconstructed object does not match the original (custom Equal)") + } + if reflect.DeepEqual(a, b) { + return nil + } + return errors.New("reconstructed object does not match the original (reflect.DeepEqual)") +}