diff --git a/.gotestfmt/downloads.gotpl b/.gotestfmt/downloads.gotpl deleted file mode 100644 index ca1cf92f55..0000000000 --- a/.gotestfmt/downloads.gotpl +++ /dev/null @@ -1,36 +0,0 @@ -{{- /*gotype: github.com/gotesttools/gotestfmt/v2/parser.Downloads*/ -}} -{{- /* -This template contains the format for a package download. -*/ -}} -{{- $settings := .Settings -}} -{{- if or .Packages .Reason -}} - {{- if or (not .Settings.HideSuccessfulDownloads) .Failed -}} - {{- if .Failed -}} - ❌ - {{- else -}} - 📥 - {{- end -}} - {{ " " }} Dependency downloads - {{ "\n" -}} - - {{- range .Packages -}} - {{- if or (not $settings.HideSuccessfulDownloads) .Failed -}} - {{- " " -}} - {{- if .Failed -}} - ❌ - {{- else -}} - 📦 - {{- end -}} - {{- " " -}} - {{- .Package }} {{ .Version -}} - {{- "\n" -}} - {{ with .Reason -}} - {{- " " -}}{{ . -}}{{ "\n" -}} - {{- end -}} - {{- end -}} - {{- end -}} - {{- with .Reason -}} - {{- " " -}}🛑 {{ . }}{{ "\n" -}} - {{- end -}} - {{- end -}} -{{- end -}} diff --git a/.gotestfmt/package.gotpl b/.gotestfmt/package.gotpl deleted file mode 100644 index 504949a86b..0000000000 --- a/.gotestfmt/package.gotpl +++ /dev/null @@ -1,42 +0,0 @@ -{{- /*gotype: github.com/gotesttools/gotestfmt/v2/parser.Package*/ -}} - -{{- $settings := .Settings -}} -{{- if and (or (not $settings.HideSuccessfulPackages) (ne .Result "PASS")) (or (not $settings.HideEmptyPackages) (ne .Result "SKIP") (ne (len .TestCases) 0)) -}} - 📦 `{{ .Name }}` - {{- with .Coverage -}} - ({{ . }}% coverage) - {{- end -}} - {{- "\n" -}} - {{- with .Reason -}} - {{- " " -}}🛑 {{ . -}}{{- "\n" -}} - {{- end -}} - {{- with .Output -}} - ```{{- "\n" -}} - {{- . -}}{{- "\n" -}} - ```{{- "\n" -}} - {{- end -}} - {{- with .TestCases -}} - {{- range . -}} - {{- if or (not $settings.HideSuccessfulTests) (ne .Result "PASS") -}} - {{- if eq .Result "PASS" -}} - ✅ - {{- else if eq .Result "SKIP" -}} - 🚧 - {{- else -}} - ❌ - {{- end -}} - {{ " " }}`{{- .Name -}}` {{ .Duration -}} - {{- "\n" -}} - - {{- with .Output -}} - ```{{- "\n" -}} - {{- formatTestOutput . $settings -}}{{- "\n" -}} - ```{{- "\n" -}} - {{- end -}} - - {{- "\n" -}} - {{- end -}} - {{- end -}} - {{- end -}} - {{- "\n" -}} -{{- end -}} diff --git a/internal/parallel/execute.go b/internal/parallel/execute.go new file mode 100644 index 0000000000..05f9a8f666 --- /dev/null +++ b/internal/parallel/execute.go @@ -0,0 +1,56 @@ +package parallel + +import ( + "runtime" + "sync" +) + +// Execute process in parallel the work function +func Execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/std/fiat-shamir/settings.go b/std/fiat-shamir/settings.go index 146a64355e..2c475e83e4 100644 --- a/std/fiat-shamir/settings.go +++ b/std/fiat-shamir/settings.go @@ -3,6 +3,9 @@ package fiatshamir import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/emulated" + gohash "hash" + "math/big" ) type Settings struct { @@ -12,6 +15,20 @@ type Settings struct { Hash hash.FieldHasher } +type SettingsBigInt struct { + Transcript *Transcript + Prefix string + BaseChallenges []big.Int + Hash gohash.Hash +} + +type SettingsEmulated[FR emulated.FieldParams] struct { + Transcript *Transcript + Prefix string + BaseChallenges []emulated.Element[FR] + Hash hash.FieldHasher +} + func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...frontend.Variable) Settings { return Settings{ Transcript: transcript, @@ -20,9 +37,39 @@ func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...fro } } +func WithTranscriptBigInt(transcript *Transcript, prefix string, baseChallenges ...big.Int) SettingsBigInt { + return SettingsBigInt{ + Transcript: transcript, + Prefix: prefix, + BaseChallenges: baseChallenges, + } +} + +func WithTranscriptFr[FR emulated.FieldParams](transcript *Transcript, prefix string, baseChallenges ...emulated.Element[FR]) SettingsEmulated[FR] { + return SettingsEmulated[FR]{ + Transcript: transcript, + Prefix: prefix, + BaseChallenges: baseChallenges, + } +} + func WithHash(hash hash.FieldHasher, baseChallenges ...frontend.Variable) Settings { return Settings{ BaseChallenges: baseChallenges, Hash: hash, } } + +func WithHashBigInt(hash gohash.Hash, baseChallenges ...big.Int) SettingsBigInt { + return SettingsBigInt{ + BaseChallenges: baseChallenges, + Hash: hash, + } +} + +func WithHashFr[FR emulated.FieldParams](hash hash.FieldHasher, baseChallenges ...emulated.Element[FR]) SettingsEmulated[FR] { + return SettingsEmulated[FR]{ + BaseChallenges: baseChallenges, + Hash: hash, + } +} diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index a715a9d98e..4da2629934 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -308,6 +308,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, claims := newClaimsManager(c, assignment) var firstChallenge []frontend.Variable + // why no bind values here? firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) if err != nil { return err @@ -327,7 +328,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(finalEvalProof) != 0 || len(proofW.RoundPolyEvaluations) != 0 { return fmt.Errorf("no proof allowed for input wire with a single claim") } @@ -470,16 +471,16 @@ func (a WireAssignment) NumVars() int { func (p Proof) Serialize() []frontend.Variable { size := 0 for i := range p { - for j := range p[i].PartialSumPolys { - size += len(p[i].PartialSumPolys[j]) + for j := range p[i].RoundPolyEvaluations { + size += len(p[i].RoundPolyEvaluations[j]) } size += len(p[i].FinalEvalProof.([]frontend.Variable)) } res := make([]frontend.Variable, 0, size) for i := range p { - for j := range p[i].PartialSumPolys { - res = append(res, p[i].PartialSumPolys[j]...) + for j := range p[i].RoundPolyEvaluations { + res = append(res, p[i].RoundPolyEvaluations[j]...) } res = append(res, p[i].FinalEvalProof.([]frontend.Variable)...) } @@ -519,9 +520,9 @@ func DeserializeProof(sorted []*Wire, serializedProof []frontend.Variable) (Proo reader := variablesReader(serializedProof) for i, wI := range sorted { if !wI.noProof() { - proof[i].PartialSumPolys = make([]polynomial.Polynomial, logNbInstances) - for j := range proof[i].PartialSumPolys { - proof[i].PartialSumPolys[j] = reader.nextN(wI.Gate.Degree() + 1) + proof[i].RoundPolyEvaluations = make([]polynomial.Polynomial, logNbInstances) + for j := range proof[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[j] = reader.nextN(wI.Gate.Degree() + 1) } } proof[i].FinalEvalProof = reader.nextN(wI.nbUniqueInputs()) diff --git a/std/gkr/gkr_test.go b/std/gkr/gkr_test.go index d24b25a95c..31ebd5c112 100644 --- a/std/gkr/gkr_test.go +++ b/std/gkr/gkr_test.go @@ -3,6 +3,7 @@ package gkr import ( "encoding/json" "fmt" + "math/big" "os" "path/filepath" "reflect" @@ -165,8 +166,8 @@ type TestCase struct { type TestCaseInfo struct { Hash HashDescription `json:"hash"` Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` + Input [][]big.Int `json:"input"` + Output [][]big.Int `json:"output"` Proof PrintableProof `json:"proof"` } @@ -275,8 +276,8 @@ func (g _select) Degree() int { type PrintableProof []PrintableSumcheckProof type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` + RoundPolyEvaluations [][]interface{} `json:"roundPolyEvaluations"` } func unmarshalProof(printable PrintableProof) (proof Proof) { @@ -294,9 +295,9 @@ func unmarshalProof(printable PrintableProof) (proof Proof) { proof[i].FinalEvalProof = nil } - proof[i].PartialSumPolys = make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)) - for k := range printable[i].PartialSumPolys { - proof[i].PartialSumPolys[k] = ToVariableSlice(printable[i].PartialSumPolys[k]) + proof[i].RoundPolyEvaluations = make([]polynomial.Polynomial, len(printable[i].RoundPolyEvaluations)) + for k := range printable[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[k] = ToVariableSlice(printable[i].RoundPolyEvaluations[k]) } } return @@ -327,7 +328,6 @@ func TestLoadCircuit(t *testing.T) { assert.Equal(t, []*Wire{}, c[0].Inputs) assert.Equal(t, []*Wire{&c[0]}, c[1].Inputs) assert.Equal(t, []*Wire{&c[1]}, c[2].Inputs) - } func TestTopSortTrivial(t *testing.T) { diff --git a/std/gkr/test_vectors/single_identity_gate_two_instances.json b/std/gkr/test_vectors/single_identity_gate_two_instances.json index ce326d0a63..fa38a03cb6 100644 --- a/std/gkr/test_vectors/single_identity_gate_two_instances.json +++ b/std/gkr/test_vectors/single_identity_gate_two_instances.json @@ -19,13 +19,13 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 5 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -8 diff --git a/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json index 2c95f044f2..a995f7197a 100644 --- a/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json +++ b/std/gkr/test_vectors/single_input_two_identity_gates_two_instances.json @@ -23,7 +23,7 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ 0, 0 @@ -34,7 +34,7 @@ "finalEvalProof": [ 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -16 @@ -45,7 +45,7 @@ "finalEvalProof": [ 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -16 diff --git a/std/gkr/test_vectors/single_input_two_outs_two_instances.json b/std/gkr/test_vectors/single_input_two_outs_two_instances.json index d348303d0e..6dace72193 100644 --- a/std/gkr/test_vectors/single_input_two_outs_two_instances.json +++ b/std/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -23,7 +23,7 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ 0, 0 @@ -34,7 +34,7 @@ "finalEvalProof": [ 0 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -4, -36, @@ -46,7 +46,7 @@ "finalEvalProof": [ 0 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -2, -12 diff --git a/std/gkr/test_vectors/single_mimc_gate_four_instances.json b/std/gkr/test_vectors/single_mimc_gate_four_instances.json index 525459ecb1..1162e56f36 100644 --- a/std/gkr/test_vectors/single_mimc_gate_four_instances.json +++ b/std/gkr/test_vectors/single_mimc_gate_four_instances.json @@ -29,18 +29,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ -1, -3 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -32640, -2239484, diff --git a/std/gkr/test_vectors/single_mimc_gate_two_instances.json b/std/gkr/test_vectors/single_mimc_gate_two_instances.json index 7fa23ce4b1..12d7755dd5 100644 --- a/std/gkr/test_vectors/single_mimc_gate_two_instances.json +++ b/std/gkr/test_vectors/single_mimc_gate_two_instances.json @@ -23,18 +23,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 1, 0 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -2187, -65536, diff --git a/std/gkr/test_vectors/single_mul_gate_two_instances.json b/std/gkr/test_vectors/single_mul_gate_two_instances.json index 75c1d59c3d..ba854e37f5 100644 --- a/std/gkr/test_vectors/single_mul_gate_two_instances.json +++ b/std/gkr/test_vectors/single_mul_gate_two_instances.json @@ -23,18 +23,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 5, 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -9, -32, diff --git a/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json index 10e5f1ff3c..e145c7d18d 100644 --- a/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json +++ b/std/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -19,13 +19,13 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ 3 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -1, 0 @@ -36,7 +36,7 @@ "finalEvalProof": [ 3 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -1, 0 diff --git a/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json index 19e127df71..e972222802 100644 --- a/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json +++ b/std/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json @@ -23,18 +23,18 @@ "proof": [ { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [], - "partialSumPolys": [] + "roundPolyEvaluations": [] }, { "finalEvalProof": [ -1, 1 ], - "partialSumPolys": [ + "roundPolyEvaluations": [ [ -3, -16 diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 6c1f19b04d..9dbc471e25 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -245,7 +245,7 @@ func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) { constLimbs := make([]*big.Int, len(v.Limbs)) for i, l := range v.Limbs { // for each limb we get it's constant value if we can, or fail. - if constLimbs[i], ok = f.api.ConstantValue(l); !ok { + if constLimbs[i], ok = f.api.Compiler().ConstantValue(l); !ok { return nil, false } } diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 5c2c700663..ac20b22b0b 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -34,6 +34,7 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) { func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { f.enforceWidthConditional(a) f.enforceWidthConditional(b) + ba, aConst := f.constantValue(a) bb, bConst := f.constantValue(b) if aConst && bConst { diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 278b9a5024..5177873adf 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -115,47 +115,98 @@ func (mc *mulCheck[T]) cleanEvaluations() { // mulMod returns a*b mod r. In practice it computes the result using a hint and // defers the actual multiplication check. func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { - f.enforceWidthConditional(a) - f.enforceWidthConditional(b) - f.enforceWidthConditional(p) - k, r, c, err := f.callMulHint(a, b, true, p) - if err != nil { - panic(err) - } - mc := mulCheck[T]{ - f: f, - a: a, - b: b, - c: c, - k: k, - r: r, - p: p, - } - f.mulChecks = append(f.mulChecks, mc) - return r + return f.mulModProfiling(a, b, p, true) + // f.enforceWidthConditional(a) + // f.enforceWidthConditional(b) + // f.enforceWidthConditional(p) + // k, r, c, err := f.callMulHint(a, b, true, p) + // if err != nil { + // panic(err) + // } + // mc := mulCheck[T]{ + // f: f, + // a: a, + // b: b, + // c: c, + // k: k, + // r: r, + // p: p, + // } + // f.mulChecks = append(f.mulChecks, mc) + // return r } // checkZero creates multiplication check a * 1 = 0 + k*p. func (f *Field[T]) checkZero(a *Element[T], p *Element[T]) { + f.mulModProfiling(a, f.shortOne(), p, false) // the method works similarly to mulMod, but we know that we are multiplying // by one and expected result should be zero. + // f.enforceWidthConditional(a) + // f.enforceWidthConditional(p) + // b := f.shortOne() + // k, r, c, err := f.callMulHint(a, b, false, p) + // if err != nil { + // panic(err) + // } + // mc := mulCheck[T]{ + // f: f, + // a: a, + // b: b, // one on single limb to speed up the polynomial evaluation + // c: c, + // k: k, + // r: r, // expected to be zero on zero limbs. + // p: p, + // } + // f.mulChecks = append(f.mulChecks, mc) +} + +func (f *Field[T]) mulModProfiling(a, b *Element[T], p *Element[T], isMulMod bool) *Element[T] { f.enforceWidthConditional(a) - f.enforceWidthConditional(p) - b := f.shortOne() - k, r, c, err := f.callMulHint(a, b, false, p) + f.enforceWidthConditional(b) + k, r, c, err := f.callMulHint(a, b, isMulMod, p) if err != nil { panic(err) } mc := mulCheck[T]{ f: f, a: a, - b: b, // one on single limb to speed up the polynomial evaluation + b: b, c: c, k: k, - r: r, // expected to be zero on zero limbs. - p: p, + r: r, } - f.mulChecks = append(f.mulChecks, mc) + var toCommit []frontend.Variable + toCommit = append(toCommit, mc.a.Limbs...) + toCommit = append(toCommit, mc.b.Limbs...) + toCommit = append(toCommit, mc.r.Limbs...) + toCommit = append(toCommit, mc.k.Limbs...) + toCommit = append(toCommit, mc.c.Limbs...) + multicommit.WithCommitment(f.api, func(api frontend.API, commitment frontend.Variable) error { + // we do nothing. We just want to ensure that we count the commitments + return nil + }, toCommit...) + // XXX: or use something variable to count the commitments and constraints properly. Maybe can use 123 from hint? + commitment := 123 + + // for efficiency, we compute all powers of the challenge as slice at. + coefsLen := max(len(mc.a.Limbs), len(mc.b.Limbs), + len(mc.c.Limbs), len(mc.k.Limbs)) + at := make([]frontend.Variable, coefsLen) + at[0] = commitment + for i := 1; i < len(at); i++ { + at[i] = f.api.Mul(at[i-1], commitment) + } + mc.evalRound1(at) + mc.evalRound2(at) + // evaluate p(X) at challenge + pval := f.evalWithChallenge(f.Modulus(), at) + // compute (2^t-X) at challenge + coef := big.NewInt(1) + coef.Lsh(coef, f.fParams.BitsPerLimb()) + ccoef := f.api.Sub(coef, commitment) + // verify all mulchecks + mc.check(f.api, pval.evaluation, ccoef) + return r } // evalWithChallenge represents element a as a polynomial a(X) and evaluates at diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index a9f0d9cda3..2ef1f26889 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -3,10 +3,10 @@ package emulated import ( "errors" "fmt" - "math/bits" - "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/selector" + "math/big" + "math/bits" ) // Div computes a/b and returns it. It uses [DivHint] as a hint function. @@ -368,3 +368,37 @@ type overflowError struct { func (e overflowError) Error() string { return fmt.Sprintf("op %s overflow %d exceeds max %d", e.op, e.nextOverflow, e.maxOverflow) } + +func (f *Field[T]) String(a *Element[T]) string { + // for debug only, if is not test engine then no-op + var fp T + blimbs := make([]*big.Int, len(a.Limbs)) + for i, v := range a.Limbs { + switch vv := v.(type) { + case *big.Int: + blimbs[i] = vv + case big.Int: + blimbs[i] = &vv + case int: + blimbs[i] = new(big.Int) + blimbs[i].SetInt64(int64(vv)) + case uint: + blimbs[i] = new(big.Int) + blimbs[i].SetUint64(uint64(vv)) + default: + return "???" + } + } + res := new(big.Int) + err := recompose(blimbs, fp.BitsPerLimb(), res) + if err != nil { + return "!!!" + } + reduced := new(big.Int).Mod(res, fp.Modulus()) + return reduced.String() +} + +func (f *Field[T]) Println(a *Element[T]) { + res := f.String(a) + fmt.Println(res) +} diff --git a/std/math/polynomial/polynomial.go b/std/math/polynomial/polynomial.go index e09ef69ef1..511dc4107d 100644 --- a/std/math/polynomial/polynomial.go +++ b/std/math/polynomial/polynomial.go @@ -22,6 +22,10 @@ type Univariate[FR emulated.FieldParams] []emulated.Element[FR] // coefficients. type Multilinear[FR emulated.FieldParams] []emulated.Element[FR] +func (ml *Multilinear[FR]) NumVars() int { + return bits.Len(uint(len(*ml) - 1)) +} + func valueOf[FR emulated.FieldParams](univ []*big.Int) []emulated.Element[FR] { ret := make([]emulated.Element[FR], len(univ)) for i := range univ { @@ -61,6 +65,9 @@ type Polynomial[FR emulated.FieldParams] struct { // FromSlice maps slice of emulated element values to their references. func FromSlice[FR emulated.FieldParams](in []emulated.Element[FR]) []*emulated.Element[FR] { + if len(in) == 0 { + return []*emulated.Element[FR]{} + } r := make([]*emulated.Element[FR], len(in)) for i := range in { r[i] = &in[i] diff --git a/std/polynomial/polynomial.go b/std/polynomial/polynomial.go index 0953cb3ac7..4bd0940023 100644 --- a/std/polynomial/polynomial.go +++ b/std/polynomial/polynomial.go @@ -3,6 +3,7 @@ package polynomial import ( "math/bits" + "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/frontend" ) @@ -11,6 +12,53 @@ type MultiLin []frontend.Variable var minFoldScaledLogSize = 16 +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + +// Evaluate assumes len(m) = 1 << len(at) +// it doesn't modify m +func (m MultiLin) EvaluatePool(api frontend.API, at []frontend.Variable, pool *Pool) frontend.Variable { + _m := _clone(m, pool) + + /*minFoldScaledLogSize := 16 + if api is r1cs { + minFoldScaledLogSize = math.MaxInt64 // no scaling for r1cs + }*/ + + scaleCorrectionFactor := frontend.Variable(1) + // at each iteration fold by at[i] + for len(_m) > 1 { + if len(_m) >= minFoldScaledLogSize { + scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0])) + } else { + _m.Fold(api, at[0]) + } + _m = _m[:len(_m)/2] + at = at[1:] + } + + if len(at) != 0 { + panic("incompatible evaluation vector size") + } + + result := _m[0] + + _dump(_m, pool) + + return api.Mul(result, scaleCorrectionFactor) +} + // Evaluate assumes len(m) = 1 << len(at) // it doesn't modify m func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Variable { @@ -27,7 +75,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va if len(_m) >= minFoldScaledLogSize { scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0])) } else { - _m.fold(api, at[0]) + _m.Fold(api, at[0]) } _m = _m[:len(_m)/2] at = at[1:] @@ -42,7 +90,7 @@ func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Va // fold fixes the value of m's first variable to at, thus halving m's required bookkeeping table size // WARNING: The user should halve m themselves after the call -func (m MultiLin) fold(api frontend.API, at frontend.Variable) { +func (m MultiLin) Fold(api frontend.API, at frontend.Variable) { zero := m[:len(m)/2] one := m[len(m)/2:] for j := range zero { @@ -51,6 +99,43 @@ func (m MultiLin) fold(api frontend.API, at frontend.Variable) { } } +func (m *MultiLin) FoldParallel(api frontend.API, r frontend.Variable) utils.Task { + mid := len(*m) / 2 + bottom, top := (*m)[:mid], (*m)[mid:] + + *m = bottom + + return func(start, end int) { + var t frontend.Variable // no need to update the top part + for i := start; i < end; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t = api.Sub(&top[i], &bottom[i]) + t = api.Mul(&t, &r) + bottom[i] = api.Add(&bottom[i], &t) + } + } +} + +// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] +func (m *MultiLin) Eq(api frontend.API, q []frontend.Variable) { + n := len(q) + + if len(*m) != 1< p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large +} + +func (p *Pool) Make(n int) []frontend.Variable { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]frontend.Variable) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse.Load(ptr); ok { + p.inUse.Delete(ptr) + metadata.(inUseData).pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } + } +} + +func (p *Pool) addInUse(ptr *frontend.Variable, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) + + if prevPcs, ok := p.inUse.Load(ptr); ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.(inUseData).allocatedFor))) + } + p.inUse.Store(ptr, inUseData{ + allocatedFor: pcs[:n], + pool: pool, + }) +} + +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) +} + +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + p.inUse.Range(func(_, pcs any) bool { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.(inUseData).allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) + } + return true + }) +} + +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) make(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []frontend.Variable) *frontend.Variable { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*frontend.Variable)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse + } + + stats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(stats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []frontend.Variable) []frontend.Variable { + res := p.Make(len(slice)) + copy(res, slice) + return res +} diff --git a/std/recursion/gkr/gkr_nonnative.go b/std/recursion/gkr/gkr_nonnative.go new file mode 100644 index 0000000000..191fec699f --- /dev/null +++ b/std/recursion/gkr/gkr_nonnative.go @@ -0,0 +1,1259 @@ +package gkrnonative + +import ( + "fmt" + cryptofiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/parallel" + fiatshamir "github.com/consensys/gnark/std/fiat-shamir" + "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/polynomial" + "github.com/consensys/gnark/std/recursion" + "github.com/consensys/gnark/std/recursion/sumcheck" + "math/big" + "slices" + "strconv" +) + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(*sumcheck.BigIntEngine, ...*big.Int) *big.Int + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +// Gate must be a low-degree polynomial +type GateEmulated[FR emulated.FieldParams] interface { + Evaluate(*sumcheck.EmuEngine[FR], ...*emulated.Element[FR]) *emulated.Element[FR] + Degree() int +} + +type WireEmulated[FR emulated.FieldParams] struct { + Gate GateEmulated[FR] + Inputs []*WireEmulated[FR] // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) nbUniqueInputs() int { + set := make(map[*Wire]struct{}, len(w.Inputs)) + for _, in := range w.Inputs { + set[in] = struct{}{} + } + return len(set) +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = utils.Max(res, c[i].Gate.Degree()) + } + } + return res +} + +type CircuitEmulated[FR emulated.FieldParams] []WireEmulated[FR] + +func (w WireEmulated[FR]) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w WireEmulated[FR]) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w WireEmulated[FR]) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w WireEmulated[FR]) nbUniqueInputs() int { + set := make(map[*WireEmulated[FR]]struct{}, len(w.Inputs)) + for _, in := range w.Inputs { + set[in] = struct{}{} + } + return len(set) +} + +func (w WireEmulated[FR]) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]sumcheck.NativeMultilinear + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignmentEmulated[FR emulated.FieldParams] map[*WireEmulated[FR]]polynomial.Multilinear[FR] + +type Proofs[FR emulated.FieldParams] []sumcheck.Proof[FR] // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaimsEmulated[FR emulated.FieldParams] struct { + wire *WireEmulated[FR] + evaluationPoints [][]emulated.Element[FR] + claimedEvaluations []emulated.Element[FR] + manager *claimsManagerEmulated[FR] // WARNING: Circular references + verifier *GKRVerifier[FR] + engine *sumcheck.EmuEngine[FR] +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) NbClaims() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) NbVars() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) CombinedSum(a *emulated.Element[FR]) *emulated.Element[FR] { + evalsAsPoly := polynomial.Univariate[FR](e.claimedEvaluations) + return e.verifier.p.EvalUnivariate(evalsAsPoly, a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]) AssertEvaluation(r []*emulated.Element[FR], combinationCoeff, expectedValue *emulated.Element[FR], proof sumcheck.EvaluationProof) error { + inputEvaluationsNoRedundancy := proof.([]emulated.Element[FR]) + field, err := emulated.NewField[FR](e.verifier.api) + if err != nil { + return fmt.Errorf("failed to create field: %w", err) + } + p, err := polynomial.New[FR](e.verifier.api) + if err != nil { + return err + } + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := p.EvalEqual(polynomial.FromSlice(e.evaluationPoints[numClaims-1]), r) + for i := numClaims - 2; i >= 0; i-- { + evaluation = field.Mul(evaluation, combinationCoeff) + eq := p.EvalEqual(polynomial.FromSlice(e.evaluationPoints[i]), r) + evaluation = field.Add(evaluation, eq) + } + + // the g(...) term + var gateEvaluation emulated.Element[FR] + if e.wire.IsInput() { + gateEvaluationPtr, err := p.EvalMultilinear(r, e.manager.assignment[e.wire]) + if err != nil { + return err + } + gateEvaluation = *gateEvaluationPtr + } else { + inputEvaluations := make([]emulated.Element[FR], len(e.wire.Inputs)) + indexesInProof := make(map[*WireEmulated[FR]]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, polynomial.FromSliceReferences(r), inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = *e.wire.Gate.Evaluate(e.engine, polynomial.FromSlice(inputEvaluations)...) + } + + evaluation = field.Mul(evaluation, &gateEvaluation) + + field.AssertIsEqual(evaluation, expectedValue) + return nil +} + +type claimsManagerEmulated[FR emulated.FieldParams] struct { + claimsMap map[*WireEmulated[FR]]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR] + assignment WireAssignmentEmulated[FR] +} + +func newClaimsManagerEmulated[FR emulated.FieldParams](c CircuitEmulated[FR], assignment WireAssignmentEmulated[FR], verifier GKRVerifier[FR]) (claims claimsManagerEmulated[FR]) { + claims.assignment = assignment + claims.claimsMap = make(map[*WireEmulated[FR]]*eqTimesGateEvalSumcheckLazyClaimsEmulated[FR], len(c)) + engine, err := sumcheck.NewEmulatedEngine[FR](verifier.api) + if err != nil { + panic(err) + } + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaimsEmulated[FR]{ + wire: wire, + evaluationPoints: make([][]emulated.Element[FR], 0, wire.NbClaims()), + claimedEvaluations: make(polynomial.Multilinear[FR], wire.NbClaims()), + manager: &claims, + verifier: &verifier, + engine: engine, + } + } + return +} + +func (m *claimsManagerEmulated[FR]) add(wire *WireEmulated[FR], evaluationPoint []emulated.Element[FR], evaluation emulated.Element[FR]) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManagerEmulated[FR]) getLazyClaim(wire *WireEmulated[FR]) *eqTimesGateEvalSumcheckLazyClaimsEmulated[FR] { + return m.claimsMap[wire] +} + +func (m *claimsManagerEmulated[FR]) deleteClaim(wire *WireEmulated[FR]) { + delete(m.claimsMap, wire) +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment +} + +func newClaimsManager(c Circuit, assignment WireAssignment) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]big.Int, 0, wire.NbClaims()), + claimedEvaluations: make([]big.Int, wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []big.Int, evaluation big.Int) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getClaim(engine *sumcheck.BigIntEngine, wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + engine: engine, + } + + if wire.IsInput() { + res.inputPreprocessors = []sumcheck.NativeMultilinear{m.assignment[wire]} + } else { + res.inputPreprocessors = make([]sumcheck.NativeMultilinear, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.assignment[inputW].Clone() + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]big.Int // x in the paper + claimedEvaluations []big.Int // y in the paper + manager *claimsManager +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]big.Int // x in the paper + claimedEvaluations []big.Int // y in the paper + manager *claimsManager + engine *sumcheck.BigIntEngine + inputPreprocessors []sumcheck.NativeMultilinear // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq sumcheck.NativeMultilinear // ∑_i τ_i eq(x_i, -) +} + +func (e *eqTimesGateEvalSumcheckClaims) NbClaims() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckClaims) NbVars() int { + return len(e.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff *big.Int) sumcheck.NativePolynomial { + varsNum := c.NbVars() + eqLength := 1 << varsNum + claimsNum := c.NbClaims() + + // initialize the eq tables + c.eq = make(sumcheck.NativeMultilinear, eqLength) + for i := 0; i < eqLength; i++ { + c.eq[i] = new(big.Int) + } + c.eq[0] = c.engine.One() + sumcheck.Eq(c.engine, c.eq, sumcheck.ReferenceBigIntSlice(c.evaluationPoints[0])) + + newEq := make(sumcheck.NativeMultilinear, eqLength) + for i := 0; i < eqLength; i++ { + newEq[i] = new(big.Int) + } + aI := new(big.Int).Set(combinationCoeff) + + for k := 1; k < claimsNum; k++ { // TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(aI) + sumcheck.EqAcc(c.engine, c.eq, newEq, sumcheck.ReferenceBigIntSlice(c.evaluationPoints[k])) + if k+1 < claimsNum { + aI.Mul(aI, combinationCoeff) + } + } + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + return c.computeGJ() +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() sumcheck.NativePolynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]sumcheck.NativeMultilinear, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make(sumcheck.NativePolynomial, degGJ) + for i := range gJ { + gJ[i] = new(big.Int) + } + + step := new(big.Int) + res := make([]*big.Int, degGJ) + for i := range res { + res[i] = new(big.Int) + } + operands := make([]*big.Int, degGJ*nbInner) + for i := range operands { + operands[i] = new(big.Int) + } + + for i := 0; i < nbOuter; i++ { + block := nbOuter + i + for j := 0; j < nbInner; j++ { + // TODO: instead of set can assign? + step.Set(s[j][i]) + operands[j].Set(s[j][block]) + step = c.engine.Sub(operands[j], step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j] = c.engine.Add(operands[(d-1)*nbInner+j], step) + } + } + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(c.engine, operands[_s+1:_e]...) + summand = c.engine.Mul(summand, operands[_s]) + res[d] = c.engine.Add(res[d], summand) + _s, _e = _e, _e+nbInner + } + } + for i := 0; i < degGJ; i++ { + gJ[i] = c.engine.Add(gJ[i], res[i]) + } + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element *big.Int) sumcheck.NativePolynomial { + for i := 0; i < len(c.inputPreprocessors); i++ { + sumcheck.Fold(c.engine, c.inputPreprocessors[i], element) + } + sumcheck.Fold(c.engine, c.eq, element) + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) ProverFinalEval(r []*big.Int) sumcheck.NativeEvaluationProof { + + //defer the proof, return list of claims + evaluations := make([]big.Int, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + sumcheck.Fold(c.engine, puI, r[len(r)-1]) + puI0 := new(big.Int).Set(puI[0]) + c.manager.add(in, sumcheck.DereferenceBigIntSlice(r), *puI0) + evaluations = append(evaluations, *puI0) + } + } + + return evaluations +} + +func (e *eqTimesGateEvalSumcheckClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func setup(current *big.Int, target *big.Int, c Circuit, assignment WireAssignment, options ...OptionGkr) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func ChallengeNamesEmulated[FR emulated.FieldParams](sorted []*WireEmulated[FR], logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func (v *GKRVerifier[FR]) getChallengesFr(transcript *fiatshamir.Transcript, names []string) (challenges []emulated.Element[FR], err error) { + challenges = make([]emulated.Element[FR], len(names)) + var challenge emulated.Element[FR] + var fr FR + for i, name := range names { + nativeChallenge, err := transcript.ComputeChallenge(name) + if err != nil { + return nil, fmt.Errorf("compute challenge %s: %w", names, err) + } + // TODO: when implementing better way (construct from limbs instead of bits) then change + chBts := bits.ToBinary(v.api, nativeChallenge, bits.WithNbDigits(fr.Modulus().BitLen())) + challenge = *v.f.FromBits(chBts...) + challenges[i] = challenge + + } + return challenges, nil +} + +// Prove consistency of the claimed assignment +func Prove(current *big.Int, target *big.Int, c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.SettingsBigInt, options ...OptionGkr) (NativeProofs, error) { + be := sumcheck.NewBigIntEngine(target) + o, err := setup(current, target, c, assignment, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment) + + proof := make(NativeProofs, len(c)) + challengeNames := getFirstChallengeNames(o.nbVars, o.transcriptPrefix) + // firstChallenge called rho in the paper + firstChallenge := make([]*big.Int, len(challengeNames)) + for i := 0; i < len(challengeNames); i++ { + firstChallenge[i], _, err = sumcheck.DeriveChallengeProver(o.transcript, challengeNames[i:], nil) + if err != nil { + return nil, err + } + } + + var baseChallenge []*big.Int + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + evaluation := sumcheck.Eval(be, assignment[wire], firstChallenge) + claims.add(wire, sumcheck.DereferenceBigIntSlice(firstChallenge), *evaluation) + } + + claim := claims.getClaim(be, wire) + var finalEvalProofLen int + + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.NativeProof{ + RoundPolyEvaluations: []sumcheck.NativePolynomial{}, + FinalEvalProof: sumcheck.NativeDeferredEvalProof([]big.Int{}), + } + } else { + proof[i], err = sumcheck.Prove( + current, target, claim, + ) + if err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof + switch finalEvalProof := finalEvalProof.(type) { + case nil: + finalEvalProofCasted := sumcheck.NativeDeferredEvalProof([]big.Int{}) + proof[i].FinalEvalProof = finalEvalProofCasted + case []big.Int: + finalEvalProofLen = len(finalEvalProof) + finalEvalProofCasted := sumcheck.NativeDeferredEvalProof(finalEvalProof) + proof[i].FinalEvalProof = finalEvalProofCasted + default: + return nil, fmt.Errorf("finalEvalProof is not of type DeferredEvalProof") + } + + baseChallenge = make([]*big.Int, finalEvalProofLen) + for i := 0; i < finalEvalProofLen; i++ { + baseChallenge[i] = &finalEvalProof.([]big.Int)[i] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete, +// Use valueOfProof[FR](proof) to convert nativeproof by prover into nonnativeproof used by in-circuit verifier +func (v *GKRVerifier[FR]) Verify(api frontend.API, c CircuitEmulated[FR], assignment WireAssignmentEmulated[FR], proof Proofs[FR], transcriptSettings fiatshamir.SettingsEmulated[FR], options ...OptionEmulated[FR]) error { + o, err := v.setup(api, c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + sumcheck_verifier, err := sumcheck.NewVerifier[FR](api) + if err != nil { + return err + } + + claims := newClaimsManagerEmulated[FR](c, assignment, *v) + var firstChallenge []emulated.Element[FR] + firstChallenge, err = v.getChallengesFr(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + var baseChallenge []emulated.Element[FR] + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + if wire.IsOutput() { + var evaluation emulated.Element[FR] + evaluationPtr, err := v.p.EvalMultilinear(polynomial.FromSlice(firstChallenge), assignment[wire]) + if err != nil { + return err + } + evaluation = *evaluationPtr + claims.add(wire, firstChallenge, evaluation) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof + claim := claims.getLazyClaim(wire) + + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + // make sure finalevalproof is of type deferred for gkr + var proofLen int + switch proof := finalEvalProof.(type) { + case []emulated.Element[FR]: + proofLen = len(sumcheck.DeferredEvalProof[FR](proof)) + default: + return fmt.Errorf("finalEvalProof is not of type DeferredEvalProof") + } + + if (finalEvalProof != nil && proofLen != 0) || len(proofW.RoundPolyEvaluations) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + var evaluation emulated.Element[FR] + evaluationPtr, err := v.p.EvalMultilinear(polynomial.FromSlice(claim.evaluationPoints[0]), assignment[wire]) + if err != nil { + return err + } + evaluation = *evaluationPtr + v.f.AssertIsEqual(&claim.claimedEvaluations[0], &evaluation) + } + } else if err = sumcheck_verifier.Verify( + claim, proof[i], + ); err == nil { + switch proof := finalEvalProof.(type) { + case []emulated.Element[FR]: + baseChallenge = sumcheck.DeferredEvalProof[FR](proof) + default: + return fmt.Errorf("finalEvalProof is not of type DeferredEvalProof") + } + _ = baseChallenge + } else { + return err + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate[*sumcheck.BigIntEngine, *big.Int]{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +type IdentityGate[AE sumcheck.ArithEngine[E], E element] struct{} + +func (IdentityGate[AE, E]) Evaluate(api AE, input ...E) E { + return input[0] +} + +func (IdentityGate[AE, E]) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsListEmulated[FR emulated.FieldParams](c CircuitEmulated[FR], indexes map[*WireEmulated[FR]]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortDataEmulated[FR emulated.FieldParams] struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*WireEmulated[FR]]int + leastReady int +} + +func (d *topSortDataEmulated[FR]) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMapEmulated[FR emulated.FieldParams](c CircuitEmulated[FR]) map[*WireEmulated[FR]]int { + res := make(map[*WireEmulated[FR]]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusListEmulated[FR emulated.FieldParams](c CircuitEmulated[FR]) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// TODO: Have this use algo_utils.TopologicalSort underneath + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit, target *big.Int) WireAssignment { + + engine := sumcheck.NewBigIntEngine(target) + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = utils.Max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]*big.Int, nbInstances) + } + } + + parallel.Execute(nbInstances, func(start, end int) { + ins := make([]*big.Int, maxNbIns) + for i := start; i < end; i++ { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(engine, ins[:len(w.Inputs)]...) + } + } + } + }) + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + if aW != nil { + return len(aW) + } + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + if aW != nil { + return aW.NumVars() + } + } + panic("empty assignment") +} + +func topologicalSortEmulated[FR emulated.FieldParams](c CircuitEmulated[FR]) []*WireEmulated[FR] { + var data topSortDataEmulated[FR] + data.index = indexMapEmulated(c) + data.outputs = outputsListEmulated(c, data.index) + data.status = statusListEmulated(c) + sorted := make([]*WireEmulated[FR], len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +func (a WireAssignmentEmulated[FR]) NumInstances() int { + for _, aW := range a { + if aW != nil { + return len(aW) + } + } + panic("empty assignment") +} + +func (a WireAssignmentEmulated[FR]) NumVars() int { + for _, aW := range a { + if aW != nil { + return aW.NumVars() + } + } + panic("empty assignment") +} + +func (p Proofs[FR]) Serialize() []emulated.Element[FR] { + size := 0 + for i := range p { + for j := range p[i].RoundPolyEvaluations { + size += len(p[i].RoundPolyEvaluations[j]) + } + switch v := p[i].FinalEvalProof.(type) { + case sumcheck.DeferredEvalProof[FR]: + size += len(v) + } + } + + res := make([]emulated.Element[FR], 0, size) + for i := range p { + for j := range p[i].RoundPolyEvaluations { + res = append(res, p[i].RoundPolyEvaluations[j]...) + } + switch v := p[i].FinalEvalProof.(type) { + case sumcheck.DeferredEvalProof[FR]: + res = append(res, v...) + } + } + if len(res) != size { + panic("bug") // TODO: Remove + } + return res +} + +func computeLogNbInstances[FR emulated.FieldParams](wires []*WireEmulated[FR], serializedProofLen int) int { + partialEvalElemsPerVar := 0 + for _, w := range wires { + if !w.noProof() { + partialEvalElemsPerVar += w.Gate.Degree() + 1 + } + serializedProofLen -= w.nbUniqueOutputs + } + return serializedProofLen / partialEvalElemsPerVar +} + +type variablesReader[FR emulated.FieldParams] []emulated.Element[FR] + +func (r *variablesReader[FR]) nextN(n int) []emulated.Element[FR] { + res := (*r)[:n] + *r = (*r)[n:] + return res +} + +func (r *variablesReader[FR]) hasNextN(n int) bool { + return len(*r) >= n +} + +func DeserializeProof[FR emulated.FieldParams](sorted []*WireEmulated[FR], serializedProof []emulated.Element[FR]) (Proofs[FR], error) { + proof := make(Proofs[FR], len(sorted)) + logNbInstances := computeLogNbInstances(sorted, len(serializedProof)) + + reader := variablesReader[FR](serializedProof) + for i, wI := range sorted { + if !wI.noProof() { + proof[i].RoundPolyEvaluations = make([]polynomial.Univariate[FR], logNbInstances) + for j := range proof[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[j] = reader.nextN(wI.Gate.Degree() + 1) + } + } + proof[i].FinalEvalProof = reader.nextN(wI.nbUniqueInputs()) + } + if reader.hasNextN(1) { + return nil, fmt.Errorf("proof too long: expected %d encountered %d", len(serializedProof)-len(reader), len(serializedProof)) + } + return proof, nil +} + +type element any + +type MulGate[AE sumcheck.ArithEngine[E], E element] struct{} + +func (g MulGate[AE, E]) Evaluate(api AE, x ...E) E { + if len(x) != 2 { + panic("mul has fan-in 2") + } + return api.Mul(x[0], x[1]) +} + +// TODO: Degree must take nbInputs as an argument and return degree = nbInputs +func (g MulGate[AE, E]) Degree() int { + return 2 +} + +type AddGate[AE sumcheck.ArithEngine[E], E element] struct{} + +func (a AddGate[AE, E]) Evaluate(api AE, v ...E) E { + switch len(v) { + case 0: + return api.Const(big.NewInt(0)) + case 1: + return v[0] + } + rest := v[2:] + res := api.Add(v[0], v[1]) + for _, e := range rest { + res = api.Add(res, e) + } + return res +} + +func (a AddGate[AE, E]) Degree() int { + return 1 +} diff --git a/std/recursion/gkr/gkr_nonnative_test.go b/std/recursion/gkr/gkr_nonnative_test.go new file mode 100644 index 0000000000..c1fd696e3a --- /dev/null +++ b/std/recursion/gkr/gkr_nonnative_test.go @@ -0,0 +1,957 @@ +package gkrnonative + +import ( + "encoding/json" + "fmt" + gohash "hash" + "math/big" + "os" + "path/filepath" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + fpbn254 "github.com/consensys/gnark-crypto/ecc/bn254/fp" + frbn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + fiatshamir "github.com/consensys/gnark/std/fiat-shamir" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/std/math/polynomial" + "github.com/consensys/gnark/std/recursion" + "github.com/consensys/gnark/std/recursion/gkr/utils" + "github.com/consensys/gnark/std/recursion/sumcheck" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/assert" +) + +var Gates = map[string]Gate{ + "identity": IdentityGate[*sumcheck.BigIntEngine, *big.Int]{}, + "add": AddGate[*sumcheck.BigIntEngine, *big.Int]{}, + "mul": MulGate[*sumcheck.BigIntEngine, *big.Int]{}, +} + +func TestGkrVectorsEmulated(t *testing.T) { + current := ecc.BN254.ScalarField() + var fp emparams.BN254Fp + testDirPath := "./test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + t.Error(err) + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() && filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path, *current, *fp.Modulus())) + t.Run(noExt+"_verifier", generateTestVerifier[emparams.BN254Fp](path)) + } + } +} + +func proofEquals(expected NativeProofs, seen NativeProofs) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + // todo: REMOVE GKR PROOF ABSTRACTION FROM PROOFEQUALS + xfinalEvalProofSeen := xSeen.FinalEvalProof + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.(sumcheck.NativeDeferredEvalProof); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := utils.SliceEqualsBigInt(x.FinalEvalProof.(sumcheck.NativeDeferredEvalProof), + xfinalEvalProofSeen.(sumcheck.NativeDeferredEvalProof)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + + roundPolyEvals := make([]sumcheck.NativePolynomial, len(x.RoundPolyEvaluations)) + copy(roundPolyEvals, x.RoundPolyEvaluations) + + roundPolyEvalsSeen := make([]sumcheck.NativePolynomial, len(xSeen.RoundPolyEvaluations)) + copy(roundPolyEvalsSeen, xSeen.RoundPolyEvaluations) + + for i, poly := range roundPolyEvals { + if err := utils.SliceEqualsBigInt(sumcheck.DereferenceBigIntSlice(poly), sumcheck.DereferenceBigIntSlice(roundPolyEvalsSeen[i])); err != nil { + return err + } + } + } + return nil +} + +func generateTestProver(path string, current big.Int, target big.Int) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path, target) + assert.NoError(t, err) + proof, err := Prove(¤t, &target, testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHashBigInt(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier[FR emulated.FieldParams](path string) func(t *testing.T) { + + return func(t *testing.T) { + + testCase, err := getTestCase[FR](path) + assert := test.NewAssert(t) + assert.NoError(err) + + assignment := &GkrVerifierCircuitEmulated[FR]{ + Input: testCase.Input, + Output: testCase.Output, + SerializedProof: testCase.Proof.Serialize(), + ToFail: false, + TestCaseName: path, + } + + validCircuit := &GkrVerifierCircuitEmulated[FR]{ + Input: make([][]emulated.Element[FR], len(testCase.Input)), + Output: make([][]emulated.Element[FR], len(testCase.Output)), + SerializedProof: make([]emulated.Element[FR], len(assignment.SerializedProof)), + ToFail: false, + TestCaseName: path, + } + + fillWithBlanks(validCircuit.Input, len(testCase.Input[0])) + fillWithBlanks(validCircuit.Output, len(testCase.Input[0])) + + assert.CheckCircuit(validCircuit, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.WithValidAssignment(assignment)) + } +} + +type GkrVerifierCircuitEmulated[FR emulated.FieldParams] struct { + Input [][]emulated.Element[FR] + Output [][]emulated.Element[FR] `gnark:",public"` + SerializedProof []emulated.Element[FR] + ToFail bool + TestCaseName string +} + +func (c *GkrVerifierCircuitEmulated[FR]) Define(api frontend.API) error { + var fr FR + var testCase *TestCaseVerifier[FR] + var proof Proofs[FR] + var err error + + v, err := NewGKRVerifier[FR](api) + if err != nil { + return fmt.Errorf("new verifier: %w", err) + } + + if testCase, err = getTestCase[FR](c.TestCaseName); err != nil { + return err + } + sorted := topologicalSortEmulated(testCase.Circuit) + + if proof, err = DeserializeProof(sorted, c.SerializedProof); err != nil { + return err + } + assignment := makeInOutAssignment(testCase.Circuit, c.Input, c.Output) + + // initiating hash in bitmode + hsh, err := recursion.NewHash(api, fr.Modulus(), true) + if err != nil { + return err + } + + return v.Verify(api, testCase.Circuit, assignment, proof, fiatshamir.WithHashFr[FR](hsh)) +} + +func makeInOutAssignment[FR emulated.FieldParams](c CircuitEmulated[FR], inputValues [][]emulated.Element[FR], outputValues [][]emulated.Element[FR]) WireAssignmentEmulated[FR] { + sorted := topologicalSortEmulated(c) + res := make(WireAssignmentEmulated[FR], len(inputValues)+len(outputValues)) + inI, outI := 0, 0 + for _, w := range sorted { + if w.IsInput() { + res[w] = inputValues[inI] + inI++ + } else if w.IsOutput() { + res[w] = outputValues[outI] + outI++ + } + } + return res +} + +func fillWithBlanks[FR emulated.FieldParams](slice [][]emulated.Element[FR], size int) { + for i := range slice { + slice[i] = make([]emulated.Element[FR], size) + } +} + +type TestCaseVerifier[FR emulated.FieldParams] struct { + Circuit CircuitEmulated[FR] + Hash utils.HashDescription + Proof Proofs[FR] + Input [][]emulated.Element[FR] + Output [][]emulated.Element[FR] + Name string +} +type TestCaseInfo struct { + Hash utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]interface{}) + +func getTestCase[FR emulated.FieldParams](path string) (*TestCaseVerifier[FR], error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + cse, ok := testCases[path].(*TestCaseVerifier[FR]) + if !ok { + var bytes []byte + cse = &TestCaseVerifier[FR]{} + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + if cse.Circuit, err = getCircuitEmulated[FR](filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + + nativeProofs := unmarshalProof(info.Proof) + proofs := make(Proofs[FR], len(nativeProofs)) + for i, proof := range nativeProofs { + proofs[i] = sumcheck.ValueOfProof[FR](proof) + } + cse.Proof = proofs + + cse.Input = utils.ToVariableSliceSliceFr[FR](info.Input) + cse.Output = utils.ToVariableSliceSliceFr[FR](info.Output) + cse.Hash = info.Hash + cse.Name = path + testCases[path] = cse + } else { + return nil, err + } + } + + return cse, nil +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]interface{}) + +func getCircuit(path string) (circuit Circuit, err error) { + path, err = filepath.Abs(path) + if err != nil { + return + } + var ok bool + if circuit, ok = circuitCache[path].(Circuit); ok { + return + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit, err = toCircuit(circuitInfo) + if err == nil { + circuitCache[path] = circuit + } + } + } + return +} + +func getCircuitEmulated[FR emulated.FieldParams](path string) (circuit CircuitEmulated[FR], err error) { + path, err = filepath.Abs(path) + if err != nil { + return + } + var ok bool + if circuit, ok = circuitCache[path].(CircuitEmulated[FR]); ok { + return + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit, err = ToCircuitEmulated[FR](circuitInfo) + if err == nil { + circuitCache[path] = circuit + } + } + } + return +} + +func ToCircuitEmulated[FR emulated.FieldParams](c CircuitInfo) (circuit CircuitEmulated[FR], err error) { + var GatesEmulated = map[string]GateEmulated[FR]{ + "identity": IdentityGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{}, + "add": AddGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{}, + "mul": MulGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{}, + } + + circuit = make(CircuitEmulated[FR], len(c)) + for i, wireInfo := range c { + circuit[i].Inputs = make([]*WireEmulated[FR], len(wireInfo.Inputs)) + for iAsInput, iAsWire := range wireInfo.Inputs { + input := &circuit[iAsWire] + circuit[i].Inputs[iAsInput] = input + } + + var found bool + if circuit[i].Gate, found = GatesEmulated[wireInfo.Gate]; !found && wireInfo.Gate != "" { + err = fmt.Errorf("undefined gate \"%s\"", wireInfo.Gate) + } + } + + return +} + +func toCircuit(c CircuitInfo) (circuit Circuit, err error) { + + circuit = make(Circuit, len(c)) + for i, wireInfo := range c { + circuit[i].Inputs = make([]*Wire, len(wireInfo.Inputs)) + for iAsInput, iAsWire := range wireInfo.Inputs { + input := &circuit[iAsWire] + circuit[i].Inputs[iAsInput] = input + } + + var found bool + if circuit[i].Gate, found = Gates[wireInfo.Gate]; !found && wireInfo.Gate != "" { + err = fmt.Errorf("undefined gate \"%s\"", wireInfo.Gate) + } + } + + return +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof [][]uint64 `json:"finalEvalProof"` + RoundPolyEvaluations [][][]uint64 `json:"roundPolyEvaluations"` +} + +func unmarshalProof(printable []PrintableSumcheckProof) (proof NativeProofs) { + proof = make(NativeProofs, len(printable)) + + for i := range printable { + if printable[i].FinalEvalProof != nil { + finalEvalProof := make(sumcheck.NativeDeferredEvalProof, len(printable[i].FinalEvalProof)) + for k, val := range printable[i].FinalEvalProof { + var temp big.Int + temp.SetUint64(val[0]) + for _, v := range val[1:] { + temp.Lsh(&temp, 64).Add(&temp, new(big.Int).SetUint64(v)) + } + finalEvalProof[k] = temp + } + proof[i].FinalEvalProof = finalEvalProof + } else { + proof[i].FinalEvalProof = nil + } + + proof[i].RoundPolyEvaluations = make([]sumcheck.NativePolynomial, len(printable[i].RoundPolyEvaluations)) + for k, evals := range printable[i].RoundPolyEvaluations { + proof[i].RoundPolyEvaluations[k] = make(sumcheck.NativePolynomial, len(evals)) + for j, eval := range evals { + var temp big.Int + temp.SetUint64(eval[0]) + for _, v := range eval[1:] { + temp.Lsh(&temp, 64).Add(&temp, new(big.Int).SetUint64(v)) + } + proof[i].RoundPolyEvaluations[k][j] = &temp + } + } + } + return proof +} + +func TestLoadCircuit(t *testing.T) { + type FR = emulated.BN254Fp + c, err := getCircuitEmulated[FR]("test_vectors/resources/two_identity_gates_composed_single_input.json") + assert.NoError(t, err) + assert.Equal(t, []*WireEmulated[FR]{}, c[0].Inputs) + assert.Equal(t, []*WireEmulated[FR]{&c[0]}, c[1].Inputs) + assert.Equal(t, []*WireEmulated[FR]{&c[1]}, c[2].Inputs) +} + +func TestTopSortTrivial(t *testing.T) { + type FR = emulated.BN254Fp + c := make(CircuitEmulated[FR], 2) + c[0].Inputs = []*WireEmulated[FR]{&c[1]} + sorted := topologicalSortEmulated(c) + assert.Equal(t, []*WireEmulated[FR]{&c[1], &c[0]}, sorted) +} + +func TestTopSortSingleGate(t *testing.T) { + type FR = emulated.BN254Fp + c := make(CircuitEmulated[FR], 3) + c[0].Inputs = []*WireEmulated[FR]{&c[1], &c[2]} + sorted := topologicalSortEmulated(c) + expected := []*WireEmulated[FR]{&c[1], &c[2], &c[0]} + assert.True(t, utils.SliceEqual(sorted, expected)) //TODO: Remove + utils.AssertSliceEqual(t, sorted, expected) + assert.Equal(t, c[0].nbUniqueOutputs, 0) + assert.Equal(t, c[1].nbUniqueOutputs, 1) + assert.Equal(t, c[2].nbUniqueOutputs, 1) +} + +func TestTopSortDeep(t *testing.T) { + type FR = emulated.BN254Fp + c := make(CircuitEmulated[FR], 4) + c[0].Inputs = []*WireEmulated[FR]{&c[2]} + c[1].Inputs = []*WireEmulated[FR]{&c[3]} + c[2].Inputs = []*WireEmulated[FR]{} + c[3].Inputs = []*WireEmulated[FR]{&c[0]} + sorted := topologicalSortEmulated(c) + assert.Equal(t, []*WireEmulated[FR]{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + type FR = emulated.BN254Fp + c := make(CircuitEmulated[FR], 10) + c[0].Inputs = []*WireEmulated[FR]{&c[3], &c[8]} + c[1].Inputs = []*WireEmulated[FR]{&c[6]} + c[2].Inputs = []*WireEmulated[FR]{&c[4]} + c[3].Inputs = []*WireEmulated[FR]{} + c[4].Inputs = []*WireEmulated[FR]{} + c[5].Inputs = []*WireEmulated[FR]{&c[9]} + c[6].Inputs = []*WireEmulated[FR]{&c[9]} + c[7].Inputs = []*WireEmulated[FR]{&c[9], &c[5], &c[2]} + c[8].Inputs = []*WireEmulated[FR]{&c[4], &c[3]} + c[9].Inputs = []*WireEmulated[FR]{} + + sorted := topologicalSortEmulated(c) + sortedExpected := []*WireEmulated[FR]{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +var mimcSnarkTotalCalls = 0 + +// todo add ark +type MiMCCipherGate struct { +} + +func (m MiMCCipherGate) Evaluate(api *sumcheck.BigIntEngine, input ...*big.Int) *big.Int { + mimcSnarkTotalCalls++ + + if len(input) != 2 { + panic("mimc has fan-in 2") + } + sum := api.Add(input[0], input[1]) + sumSquared := api.Mul(sum, sum) + sumCubed := api.Mul(sumSquared, sum) + return api.Mul(sumCubed, sum) +} + +func (m MiMCCipherGate) Degree() int { + return 7 +} + +type _select int + +func init() { + Gates["mimc"] = MiMCCipherGate{} + Gates["select-input-3"] = _select(2) +} + +func (g _select) Evaluate(_ *sumcheck.BigIntEngine, in ...*big.Int) *big.Int { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} + +type TestCase struct { + Current big.Int + Target big.Int + Circuit Circuit + Hash gohash.Hash + Proof NativeProofs + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +func (p *PrintableSumcheckProof) UnmarshalJSON(data []byte) error { + var temp struct { + FinalEvalProof [][]uint64 `json:"finalEvalProof"` + RoundPolyEvaluations [][][]uint64 `json:"roundPolyEvaluations"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + p.FinalEvalProof = temp.FinalEvalProof + + p.RoundPolyEvaluations = make([][][]uint64, len(temp.RoundPolyEvaluations)) + for i, arr2D := range temp.RoundPolyEvaluations { + p.RoundPolyEvaluations[i] = make([][]uint64, len(arr2D)) + for j, arr1D := range arr2D { + p.RoundPolyEvaluations[i][j] = make([]uint64, len(arr1D)) + for k, v := range arr1D { + p.RoundPolyEvaluations[i][j][k] = uint64(v) + } + } + } + return nil +} + +func newTestCase(path string, target big.Int) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash gohash.Hash + if _hash, err = utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + + proof := unmarshalProof(info.Proof) + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []big.Int + if wireAssignment, err = utils.SliceToBigIntSlice(assignmentRaw); err != nil { + return nil, err + } + fullAssignment[w] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + inOutAssignment[w] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + } + } + + fullAssignment.Complete(circuit, &target) + + for _, w := range sorted { + if w.IsOutput() { + + if err = utils.SliceEqualsBigInt(sumcheck.DereferenceBigIntSlice(inOutAssignment[w]), sumcheck.DereferenceBigIntSlice(fullAssignment[w])); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase.(*TestCase), nil +} + +type ProjAddGkrVerifierCircuit[FR emulated.FieldParams] struct { + Circuit CircuitEmulated[FR] + Input [][]emulated.Element[FR] + Output [][]emulated.Element[FR] `gnark:",public"` + SerializedProof []emulated.Element[FR] +} + +func (c *ProjAddGkrVerifierCircuit[FR]) Define(api frontend.API) error { + var fr FR + var proof Proofs[FR] + var err error + + v, err := NewGKRVerifier[FR](api) + if err != nil { + return fmt.Errorf("new verifier: %w", err) + } + + sorted := topologicalSortEmulated(c.Circuit) + + if proof, err = DeserializeProof(sorted, c.SerializedProof); err != nil { + return err + } + assignment := makeInOutAssignment(c.Circuit, c.Input, c.Output) + + // initiating hash in bitmode, since bn254 basefield is bigger than scalarfield + hsh, err := recursion.NewHash(api, fr.Modulus(), true) + if err != nil { + return err + } + + return v.Verify(api, c.Circuit, assignment, proof, fiatshamir.WithHashFr[FR](hsh)) +} + +func testDblAddSelectGKRInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, target *big.Int, inputs [][]*big.Int, outputs [][]*big.Int) { + folding := []*big.Int{ + big.NewInt(1), + big.NewInt(2), + big.NewInt(3), + big.NewInt(4), + big.NewInt(5), + big.NewInt(6), + } + c := make(Circuit, 8) + c[7] = Wire{ + Gate: sumcheck.DblAddSelectGate[*sumcheck.BigIntEngine, *big.Int]{Folding: folding}, + Inputs: []*Wire{&c[0], &c[1], &c[2], &c[3], &c[4], &c[5], &c[6]}, + } + + res := make([]*big.Int, len(inputs[0])) + for i := 0; i < len(inputs[0]); i++ { + res[i] = c[7].Gate.Evaluate(sumcheck.NewBigIntEngine(target), inputs[0][i], inputs[1][i], inputs[2][i], inputs[3][i], inputs[4][i], inputs[5][i], inputs[6][i]) + } + + foldingEmulated := make([]emulated.Element[FR], len(folding)) + for i, f := range folding { + foldingEmulated[i] = emulated.ValueOf[FR](f) + } + cEmulated := make(CircuitEmulated[FR], len(c)) + cEmulated[7] = WireEmulated[FR]{ + Gate: sumcheck.DblAddSelectGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{ + Folding: polynomial.FromSlice(foldingEmulated), + }, + Inputs: []*WireEmulated[FR]{&cEmulated[0], &cEmulated[1], &cEmulated[2], &cEmulated[3], &cEmulated[4], &cEmulated[5], &cEmulated[6]}, + } + + assert := test.NewAssert(t) + + hash, err := recursion.NewShort(current, target) + if err != nil { + t.Errorf("new short hash: %v", err) + return + } + t.Log("Evaluating all circuit wires") + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(c) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []*big.Int + if w.IsInput() { + if inI == len(inputs) { + t.Errorf("fewer input in vector than in circuit") + return + } + assignmentRaw = inputs[inI] + inI++ + } else if w.IsOutput() { + if outI == len(outputs) { + t.Errorf("fewer output in vector than in circuit") + return + } + assignmentRaw = outputs[outI] + outI++ + } + + if assignmentRaw != nil { + var wireAssignment []big.Int + wireAssignment, err := utils.SliceToBigIntSlice(assignmentRaw) + assert.NoError(err) + fullAssignment[w] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + inOutAssignment[w] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + } + } + + fullAssignment.Complete(c, target) + + for _, w := range sorted { + if w.IsOutput() { + if err = utils.SliceEqualsBigInt(sumcheck.DereferenceBigIntSlice(inOutAssignment[w]), sumcheck.DereferenceBigIntSlice(fullAssignment[w])); err != nil { + t.Errorf("assignment mismatch: %v", err) + } + + } + } + + t.Log("Circuit evaluation complete") + proof, err := Prove(current, target, c, fullAssignment, fiatshamir.WithHashBigInt(hash)) + assert.NoError(err) + t.Log("Proof complete") + + proofEmulated := make(Proofs[FR], len(proof)) + for i, proof := range proof { + proofEmulated[i] = sumcheck.ValueOfProof[FR](proof) + } + + validCircuit := &ProjAddGkrVerifierCircuit[FR]{ + Circuit: cEmulated, + Input: make([][]emulated.Element[FR], len(inputs)), + Output: make([][]emulated.Element[FR], len(outputs)), + SerializedProof: proofEmulated.Serialize(), + } + + validAssignment := &ProjAddGkrVerifierCircuit[FR]{ + Circuit: cEmulated, + Input: make([][]emulated.Element[FR], len(inputs)), + Output: make([][]emulated.Element[FR], len(outputs)), + SerializedProof: proofEmulated.Serialize(), + } + + for i := range inputs { + validCircuit.Input[i] = make([]emulated.Element[FR], len(inputs[i])) + validAssignment.Input[i] = make([]emulated.Element[FR], len(inputs[i])) + for j := range inputs[i] { + validAssignment.Input[i][j] = emulated.ValueOf[FR](inputs[i][j]) + } + } + + for i := range outputs { + validCircuit.Output[i] = make([]emulated.Element[FR], len(outputs[i])) + validAssignment.Output[i] = make([]emulated.Element[FR], len(outputs[i])) + for j := range outputs[i] { + validAssignment.Output[i][j] = emulated.ValueOf[FR](outputs[i][j]) + } + } + + err = test.IsSolved(validCircuit, validAssignment, current) + assert.NoError(err) +} + +func ElementToBigInt(element fpbn254.Element) *big.Int { + var temp big.Int + return element.BigInt(&temp) +} + +func TestProjDblAddSelectGKR(t *testing.T) { + var P bn254.G1Affine + var Q bn254.G1Affine + var U bn254.G1Affine + var one fpbn254.Element + one.SetOne() + var zero fpbn254.Element + zero.SetZero() + + var s frbn254.Element + s.SetOne() + var r frbn254.Element + r.SetOne() + P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) + Q.ScalarMultiplicationBase(r.BigInt(new(big.Int))) + U.Add(&P, &Q) + + result, err := new(big.Int).SetString("21888242871839275222246405745257275088696311157297823662689037894645226206973", 10) + if !err { + panic("error result") + } + + var fp emparams.BN254Fp + testDblAddSelectGKRInstance[emparams.BN254Fp](t, ecc.BN254.ScalarField(), fp.Modulus(), [][]*big.Int{{ElementToBigInt(P.X), ElementToBigInt(P.X)}, {ElementToBigInt(P.Y), ElementToBigInt(P.Y)}, {ElementToBigInt(one), ElementToBigInt(one)}, {ElementToBigInt(zero), ElementToBigInt(zero)}, {ElementToBigInt(one), ElementToBigInt(one)}, {ElementToBigInt(zero), ElementToBigInt(zero)}, {ElementToBigInt(one), ElementToBigInt(one)}}, [][]*big.Int{{result, result}}) +} + +func testMultipleDblAddSelectGKRInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, target *big.Int, inputs [][]*big.Int, outputs [][]*big.Int) { + folding := []*big.Int{ + big.NewInt(1), + big.NewInt(2), + big.NewInt(3), + big.NewInt(4), + big.NewInt(5), + big.NewInt(6), + } + c := make(Circuit, 8) + // check rlc of inputs to second layer is equal to output + c[7] = Wire{ + Gate: sumcheck.DblAddSelectGate[*sumcheck.BigIntEngine, *big.Int]{Folding: folding}, + Inputs: []*Wire{&c[0], &c[1], &c[2], &c[3], &c[4], &c[5], &c[6]}, + } + + res := make([]*big.Int, len(inputs[0])) + for i := 0; i < len(inputs[0]); i++ { + res[i] = c[7].Gate.Evaluate(sumcheck.NewBigIntEngine(target), inputs[0][i], inputs[1][i], inputs[2][i], inputs[3][i], inputs[4][i], inputs[5][i], inputs[6][i]) + } + fmt.Println("res", res) + + foldingEmulated := make([]emulated.Element[FR], len(folding)) + for i, f := range folding { + foldingEmulated[i] = emulated.ValueOf[FR](f) + } + cEmulated := make(CircuitEmulated[FR], len(c)) + cEmulated[7] = WireEmulated[FR]{ + Gate: sumcheck.DblAddSelectGate[*sumcheck.EmuEngine[FR], *emulated.Element[FR]]{ + Folding: polynomial.FromSlice(foldingEmulated), + }, + Inputs: []*WireEmulated[FR]{&cEmulated[0], &cEmulated[1], &cEmulated[2], &cEmulated[3], &cEmulated[4], &cEmulated[5], &cEmulated[6]}, + } + + assert := test.NewAssert(t) + + hash, err := recursion.NewShort(current, target) + if err != nil { + t.Errorf("new short hash: %v", err) + return + } + t.Log("Evaluating all circuit wires") + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(c) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []*big.Int + if w.IsInput() { + if inI == len(inputs) { + t.Errorf("fewer input in vector than in circuit") + return + } + assignmentRaw = inputs[inI] + inI++ + } else if w.IsOutput() { + if outI == len(outputs) { + t.Errorf("fewer output in vector than in circuit") + return + } + assignmentRaw = outputs[outI] + outI++ + } + + if assignmentRaw != nil { + var wireAssignment []big.Int + wireAssignment, err := utils.SliceToBigIntSlice(assignmentRaw) + assert.NoError(err) + fullAssignment[w] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + inOutAssignment[w] = sumcheck.NativeMultilinear(utils.ConvertToBigIntSlice(wireAssignment)) + } + } + + fullAssignment.Complete(c, target) + + for _, w := range sorted { + if w.IsOutput() { + + if err = utils.SliceEqualsBigInt(sumcheck.DereferenceBigIntSlice(inOutAssignment[w]), sumcheck.DereferenceBigIntSlice(fullAssignment[w])); err != nil { + t.Errorf("assignment mismatch: %v", err) + } + + } + } + + t.Log("Circuit evaluation complete") + proof, err := Prove(current, target, c, fullAssignment, fiatshamir.WithHashBigInt(hash)) + assert.NoError(err) + t.Log("Proof complete") + + proofEmulated := make(Proofs[FR], len(proof)) + for i, proof := range proof { + proofEmulated[i] = sumcheck.ValueOfProof[FR](proof) + } + + validCircuit := &ProjAddGkrVerifierCircuit[FR]{ + Circuit: cEmulated, + Input: make([][]emulated.Element[FR], len(inputs)), + Output: make([][]emulated.Element[FR], len(outputs)), + SerializedProof: proofEmulated.Serialize(), + } + + validAssignment := &ProjAddGkrVerifierCircuit[FR]{ + Circuit: cEmulated, + Input: make([][]emulated.Element[FR], len(inputs)), + Output: make([][]emulated.Element[FR], len(outputs)), + SerializedProof: proofEmulated.Serialize(), + } + + for i := range inputs { + validCircuit.Input[i] = make([]emulated.Element[FR], len(inputs[i])) + validAssignment.Input[i] = make([]emulated.Element[FR], len(inputs[i])) + for j := range inputs[i] { + validAssignment.Input[i][j] = emulated.ValueOf[FR](inputs[i][j]) + } + } + + for i := range outputs { + validCircuit.Output[i] = make([]emulated.Element[FR], len(outputs[i])) + validAssignment.Output[i] = make([]emulated.Element[FR], len(outputs[i])) + for j := range outputs[i] { + validAssignment.Output[i][j] = emulated.ValueOf[FR](outputs[i][j]) + } + } + + err = test.IsSolved(validCircuit, validAssignment, current) + assert.NoError(err) +} + +func TestMultipleDblAddSelectGKR(t *testing.T) { + var P bn254.G1Affine + var Q bn254.G1Affine + var U bn254.G1Affine + var one fpbn254.Element + one.SetOne() + var zero fpbn254.Element + zero.SetZero() + + var s frbn254.Element + s.SetOne() + var r frbn254.Element + r.SetOne() + P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) + Q.ScalarMultiplicationBase(r.BigInt(new(big.Int))) + U.Add(&P, &Q) + + result, err := new(big.Int).SetString("21888242871839275222246405745257275088696311157297823662689037894645226206973", 10) + if !err { + panic("error result") + } + + var fp emparams.BN254Fp + testMultipleDblAddSelectGKRInstance[emparams.BN254Fp](t, ecc.BN254.ScalarField(), fp.Modulus(), [][]*big.Int{{ElementToBigInt(P.X), ElementToBigInt(P.X)}, {ElementToBigInt(P.Y), ElementToBigInt(P.Y)}, {ElementToBigInt(one), ElementToBigInt(one)}, {ElementToBigInt(zero), ElementToBigInt(zero)}, {ElementToBigInt(one), ElementToBigInt(one)}, {ElementToBigInt(zero), ElementToBigInt(zero)}, {ElementToBigInt(one), ElementToBigInt(one)}}, [][]*big.Int{{result, result}}) +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/mimc_five_levels.json b/std/recursion/gkr/test_vectors/resources/mimc_five_levels.json new file mode 100644 index 0000000000..3dd74f42b5 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/mimc_five_levels.json @@ -0,0 +1,36 @@ +[ + [ + { + "gate": "mimc", + "inputs": [[1,0], [5,5]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[2,0], [5,4]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[3,0], [5,3]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[4,0], [5,2]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[5,0], [5,1]] + } + ], + [ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} + ] +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_identity_gate.json b/std/recursion/gkr/test_vectors/resources/single_identity_gate.json new file mode 100644 index 0000000000..a44066c7b4 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_identity_gate.json @@ -0,0 +1,10 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_input_two_identity_gates.json b/std/recursion/gkr/test_vectors/resources/single_input_two_identity_gates.json new file mode 100644 index 0000000000..6181784fa8 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_input_two_identity_gates.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_input_two_outs.json b/std/recursion/gkr/test_vectors/resources/single_input_two_outs.json new file mode 100644 index 0000000000..c577c1cace --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_input_two_outs.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_mimc_gate.json b/std/recursion/gkr/test_vectors/resources/single_mimc_gate.json new file mode 100644 index 0000000000..c89e7d52ae --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_mimc_gate.json @@ -0,0 +1,7 @@ +[ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + { + "gate": "mimc", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_mimc_gate_two_instances.json b/std/recursion/gkr/test_vectors/resources/single_mimc_gate_two_instances.json new file mode 100644 index 0000000000..a75ccccfef --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_mimc_gate_two_instances.json @@ -0,0 +1,89 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ] + ], + "output": [ + [ + 128, + 2187 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, + 12779565606756425632, + 11770813148743782171 + ] + ] + ] + }, + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, + 12779565606756425632, + 11770813148743782171 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 3445061460418080392, + 1582772968760438233, + 15430626802533927355, + 10677110232782539588 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, + 12779565606756425632, + 11770813148743782171 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/single_mul_gate.json b/std/recursion/gkr/test_vectors/resources/single_mul_gate.json new file mode 100644 index 0000000000..0f65a07edf --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/single_mul_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json b/std/recursion/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json new file mode 100644 index 0000000000..26681c2f89 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json new file mode 100644 index 0000000000..cdbdb3b471 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "select-input-3", + "inputs": [0,0,1] + } +] \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate_two_instances.json b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate_two_instances.json new file mode 100644 index 0000000000..05a2a421e4 --- /dev/null +++ b/std/recursion/gkr/test_vectors/resources/two_inputs_select-input-3_gate_two_instances.json @@ -0,0 +1,65 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_inputs_select-input-3_gate.json", + "input": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692424 + ], + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692426 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 811536341909029034, + 3521849244771172087, + 18082796152498666864, + 12957593377148931088 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_identity_gate_two_instances.json b/std/recursion/gkr/test_vectors/single_identity_gate_two_instances.json new file mode 100644 index 0000000000..420584f6fa --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_identity_gate_two_instances.json @@ -0,0 +1,51 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_identity_gate.json", + "input": [ + [ + 4, + 3 + ] + ], + "output": [ + [ + 4, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 3445061460418080392, + 1582772968760438233, + 15430626802533927355, + 10677110232782539588 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 405768170954514517, + 1760924622385586043, + 18264770113104109240, + 6478796688574465544 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/std/recursion/gkr/test_vectors/single_input_two_identity_gates_two_instances.json new file mode 100644 index 0000000000..1cf156c016 --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_input_two_identity_gates_two_instances.json @@ -0,0 +1,96 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_identity_gates.json", + "input": [ + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ], + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 1309801114600745759, + 3758563846819454073, + 10262009230221415359, + 16005847429194593330 + ], + [ + 1641562985788773784, + 10408495378109679862, + 1607731544356410364, + 2789460758528902269 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692426 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 811536341909029034, + 3521849244771172087, + 18082796152498666864, + 12957593377148931088 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 34793283102800716, + 14623004755582362860, + 7566020917664053271, + 804411355194692426 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 202884085477257258, + 10103834348047568829, + 18355757093406830428, + 3239398344287232773 + ], + [ + 811536341909029034, + 3521849244771172087, + 18082796152498666864, + 12957593377148931088 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_input_two_outs_two_instances.json b/std/recursion/gkr/test_vectors/single_input_two_outs_two_instances.json new file mode 100644 index 0000000000..9f9bb7b4e4 --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -0,0 +1,102 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_outs.json", + "input": [ + [ + 1, + 2 + ] + ], + "output": [ + [ + 1, + 4 + ], + [ + 1, + 2 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [ + [ + [ + 2241063740747277757, + 6741806107110303462, + 10584378630379443447, + 11431840297086248935 + ], + [ + 1284308761996894303, + 17461779615671711157, + 12779565606756425632, + 11770813148743782171 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 11552014468118848, + 6459316162880666778, + 5573794085540653091, + 12018926454163338051 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 270512113969676344, + 13471779130730091773, + 6027598717499555621, + 10468112483619494236 + ], + [ + 1825956769295315326, + 17147532837589913005, + 17627861250985060925, + 10707841024875543332 + ], + [ + 1923244012590556228, + 16346717705102969708, + 17401129836965471330, + 2115447990305160649 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 38772122160298693, + 2654177376158557373, + 666365361690475594, + 9065178994946100760 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 135256056984838172, + 6735889565365045886, + 12237171395604553618, + 14457428278664522926 + ], + [ + 608652256431771775, + 11864758970433154873, + 18173783132801388052, + 9718195032861698316 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/single_mul_gate_two_instances.json b/std/recursion/gkr/test_vectors/single_mul_gate_two_instances.json new file mode 100644 index 0000000000..128b57f3e1 --- /dev/null +++ b/std/recursion/gkr/test_vectors/single_mul_gate_two_instances.json @@ -0,0 +1,71 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mul_gate.json", + "input": [ + [ + 4, + 3 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 8, + 9 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 3474249952014841962, + 12028090948092229382, + 15144988130097378949, + 4865233403516270609 + ], + [ + 12748314788128703, + 1253101003182465366, + 14218880088090055687, + 17914127541472937276 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 608652256431771775, + 11864758970433154873, + 18173783132801388052, + 9718195032861698319 + ], + [ + 1623072683818058068, + 7043698489542344175, + 17718848231287782113, + 7468442680588310560 + ], + [ + 1690700712310477154, + 10411643272224867119, + 5390689855380507306, + 14697156819920572021 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/std/recursion/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json new file mode 100644 index 0000000000..376025e4e9 --- /dev/null +++ b/std/recursion/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -0,0 +1,77 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_identity_gates_composed_single_input.json", + "input": [ + [ + 2, + 1 + ] + ], + "output": [ + [ + 2, + 1 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "roundPolyEvaluations": [] + }, + { + "finalEvalProof": [ + [ + 3479106886554451955, + 541048341316977072, + 10578437981560588015, + 16173759560562137918 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 21302570947489481, + 677004288128798096, + 11618204988248521184, + 10639673014910314290 + ], + [ + 0, + 0, + 0, + 0 + ] + ] + ] + }, + { + "finalEvalProof": [ + [ + 3465695695855481184, + 12604187663145896652, + 17745663229938913452, + 12139687930078893591 + ] + ], + "roundPolyEvaluations": [ + [ + [ + 67628028492419086, + 3367944782682522943, + 6118585697802276809, + 7228714139332261463 + ], + [ + 0, + 0, + 0, + 0 + ] + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/gkr/utils/util.go b/std/recursion/gkr/utils/util.go new file mode 100644 index 0000000000..c7a9399d1d --- /dev/null +++ b/std/recursion/gkr/utils/util.go @@ -0,0 +1,195 @@ +package utils + +import ( + "fmt" + gohash "hash" + "math/big" + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/math/emulated" + "github.com/stretchr/testify/assert" +) + +func SliceToBigIntSlice[T any](slice []T) ([]big.Int, error) { + elementSlice := make([]big.Int, len(slice)) + for i, v := range slice { + switch v := any(v).(type) { + case *big.Int: + elementSlice[i] = *v + case float64: + elementSlice[i] = *big.NewInt(int64(v)) + default: + return nil, fmt.Errorf("unsupported type: %T", v) + } + } + return elementSlice, nil +} + +func ConvertToBigIntSlice(input []big.Int) []*big.Int { + output := make([]*big.Int, len(input)) + for i := range input { + output[i] = &input[i] + } + return output +} + +func SliceEqualsBigInt(a []big.Int, b []big.Int) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if a[i].Cmp(&b[i]) != 0 { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func ToVariableFr[FR emulated.FieldParams](v interface{}) emulated.Element[FR] { + switch vT := v.(type) { + case float64: + return *new(emulated.Field[FR]).NewElement(int(vT)) + default: + return *new(emulated.Field[FR]).NewElement(v) + } +} + +func ToVariableSliceFr[FR emulated.FieldParams, V any](slice []V) (variableSlice []emulated.Element[FR]) { + variableSlice = make([]emulated.Element[FR], len(slice)) + for i := range slice { + variableSlice[i] = ToVariableFr[FR](slice[i]) + } + return +} + +func ToVariableSliceSliceFr[FR emulated.FieldParams, V any](sliceSlice [][]V) (variableSliceSlice [][]emulated.Element[FR]) { + variableSliceSlice = make([][]emulated.Element[FR], len(sliceSlice)) + for i := range sliceSlice { + variableSliceSlice[i] = ToVariableSliceFr[FR](sliceSlice[i]) + } + return +} + +func AssertSliceEqual[T comparable](t *testing.T, expected, seen []T) { + assert.Equal(t, len(expected), len(seen)) + for i := range seen { + assert.True(t, expected[i] == seen[i], "@%d: %v != %v", i, expected[i], seen[i]) // assert.Equal is not strict enough when comparing pointers, i.e. it compares what they refer to + } +} + +func SliceEqual[T comparable](expected, seen []T) bool { + if len(expected) != len(seen) { + return false + } + for i := range seen { + if expected[i] != seen[i] { + return false + } + } + return true +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (gohash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + var temp big.Int + inputBlockSize := (len(p)-1)/len(temp.Bytes()) + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + var temp big.Int + inputBlockSize := (len(b)-1)/len(temp.Bytes()) + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res big.Int + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + var temp big.Int + return len(temp.Bytes()) +} + +func (m *MessageCounter) BlockSize() int { + var temp big.Int + return len(temp.Bytes()) +} + +func NewMessageCounter(startState, step int) gohash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() gohash.Hash { + return func() gohash.Hash { + return NewMessageCounter(startState, step) + } +} + +type MessageCounterEmulated struct { + startState int64 + state int64 + step int64 + + // cheap trick to avoid unconstrained input errors + api frontend.API + zero frontend.Variable +} + +func (m *MessageCounterEmulated) Write(data ...frontend.Variable) { + + for i := range data { + sq1, sq2 := m.api.Mul(data[i], data[i]), m.api.Mul(data[i], data[i]) + m.zero = m.api.Sub(sq1, sq2, m.zero) + } + + m.state += int64(len(data)) * m.step +} + +func (m *MessageCounterEmulated) Sum() frontend.Variable { + return m.api.Add(m.state, m.zero) +} + +func (m *MessageCounterEmulated) Reset() { + m.zero = 0 + m.state = m.startState +} + +func NewMessageCounterEmulated(api frontend.API, startState, step int) hash.FieldHasher { + transcript := &MessageCounterEmulated{startState: int64(startState), state: int64(startState), step: int64(step), api: api} + return transcript +} + +func NewMessageCounterGeneratorEmulated(startState, step int) func(frontend.API) hash.FieldHasher { + return func(api frontend.API) hash.FieldHasher { + return NewMessageCounterEmulated(api, startState, step) + } +} diff --git a/std/recursion/sumcheck/arithengine.go b/std/recursion/sumcheck/arithengine.go index e4de69ba0a..2c2fedb28b 100644 --- a/std/recursion/sumcheck/arithengine.go +++ b/std/recursion/sumcheck/arithengine.go @@ -15,7 +15,7 @@ type element any // case of prover, it is initialized with a finite field arithmetic engine // defined over [*big.Int] or field arithmetic packages. In case of verifier, is // initialized with non-native arithmetic. -type arithEngine[E element] interface { +type ArithEngine[E element] interface { Add(a, b E) E Mul(a, b E) E Sub(a, b E) E @@ -24,74 +24,74 @@ type arithEngine[E element] interface { Const(i *big.Int) E } -// bigIntEngine performs computation reducing with given modulus. -type bigIntEngine struct { +// BigIntEngine performs computation reducing with given modulus. +type BigIntEngine struct { mod *big.Int // TODO: we should also add pools for more efficient memory management. } -func (be *bigIntEngine) Add(a, b *big.Int) *big.Int { +func (be *BigIntEngine) Add(a, b *big.Int) *big.Int { dst := new(big.Int) dst.Add(a, b) dst.Mod(dst, be.mod) return dst } -func (be *bigIntEngine) Mul(a, b *big.Int) *big.Int { +func (be *BigIntEngine) Mul(a, b *big.Int) *big.Int { dst := new(big.Int) dst.Mul(a, b) dst.Mod(dst, be.mod) return dst } -func (be *bigIntEngine) Sub(a, b *big.Int) *big.Int { +func (be *BigIntEngine) Sub(a, b *big.Int) *big.Int { dst := new(big.Int) dst.Sub(a, b) dst.Mod(dst, be.mod) return dst } -func (be *bigIntEngine) One() *big.Int { +func (be *BigIntEngine) One() *big.Int { return big.NewInt(1) } -func (be *bigIntEngine) Const(i *big.Int) *big.Int { +func (be *BigIntEngine) Const(i *big.Int) *big.Int { return new(big.Int).Set(i) } -func newBigIntEngine(mod *big.Int) *bigIntEngine { - return &bigIntEngine{mod: new(big.Int).Set(mod)} +func NewBigIntEngine(mod *big.Int) *BigIntEngine { + return &BigIntEngine{mod: new(big.Int).Set(mod)} } -// emuEngine uses non-native arithmetic for operations. -type emuEngine[FR emulated.FieldParams] struct { +// EmuEngine uses non-native arithmetic for operations. +type EmuEngine[FR emulated.FieldParams] struct { f *emulated.Field[FR] } -func (ee *emuEngine[FR]) Add(a, b *emulated.Element[FR]) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Add(a, b *emulated.Element[FR]) *emulated.Element[FR] { return ee.f.Add(a, b) } -func (ee *emuEngine[FR]) Mul(a, b *emulated.Element[FR]) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Mul(a, b *emulated.Element[FR]) *emulated.Element[FR] { return ee.f.Mul(a, b) } -func (ee *emuEngine[FR]) Sub(a, b *emulated.Element[FR]) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Sub(a, b *emulated.Element[FR]) *emulated.Element[FR] { return ee.f.Sub(a, b) } -func (ee *emuEngine[FR]) One() *emulated.Element[FR] { +func (ee *EmuEngine[FR]) One() *emulated.Element[FR] { return ee.f.One() } -func (ee *emuEngine[FR]) Const(i *big.Int) *emulated.Element[FR] { +func (ee *EmuEngine[FR]) Const(i *big.Int) *emulated.Element[FR] { return ee.f.NewElement(i) } -func newEmulatedEngine[FR emulated.FieldParams](api frontend.API) (*emuEngine[FR], error) { +func NewEmulatedEngine[FR emulated.FieldParams](api frontend.API) (*EmuEngine[FR], error) { f, err := emulated.NewField[FR](api) if err != nil { return nil, fmt.Errorf("new field: %w", err) } - return &emuEngine[FR]{f: f}, nil + return &EmuEngine[FR]{f: f}, nil } diff --git a/std/recursion/sumcheck/challenge.go b/std/recursion/sumcheck/challenge.go index fb9e87ee4c..3a8759e346 100644 --- a/std/recursion/sumcheck/challenge.go +++ b/std/recursion/sumcheck/challenge.go @@ -25,7 +25,7 @@ func getChallengeNames(prefix string, nbClaims int, nbVars int) []string { } // bindChallengeProver binds the values for challengeName using native Fiat-Shamir transcript. -func bindChallengeProver(fs *cryptofiatshamir.Transcript, challengeName string, values []*big.Int) error { +func BindChallengeProver(fs *cryptofiatshamir.Transcript, challengeName string, values []*big.Int) error { for i := range values { buf := make([]byte, 32) values[i].FillBytes(buf) @@ -39,8 +39,8 @@ func bindChallengeProver(fs *cryptofiatshamir.Transcript, challengeName string, // deriveChallengeProver binds the values for challengeName and then returns the // challenge using native Fiat-Shamir transcript. It also returns the rest of // the challenge names for used in the protocol. -func deriveChallengeProver(fs *cryptofiatshamir.Transcript, challengeNames []string, values []*big.Int) (challenge *big.Int, restChallengeNames []string, err error) { - if err = bindChallengeProver(fs, challengeNames[0], values); err != nil { +func DeriveChallengeProver(fs *cryptofiatshamir.Transcript, challengeNames []string, values []*big.Int) (challenge *big.Int, restChallengeNames []string, err error) { + if err = BindChallengeProver(fs, challengeNames[0], values); err != nil { return nil, nil, fmt.Errorf("bind: %w", err) } nativeChallenge, err := fs.ComputeChallenge(challengeNames[0]) @@ -51,6 +51,7 @@ func deriveChallengeProver(fs *cryptofiatshamir.Transcript, challengeNames []str return challenge, challengeNames[1:], nil } +// todo change this bind as limbs instead of bits, ask @arya if necessary // bindChallenge binds the values for challengeName using in-circuit Fiat-Shamir transcript. func (v *Verifier[FR]) bindChallenge(fs *fiatshamir.Transcript, challengeName string, values []emulated.Element[FR]) error { for i := range values { diff --git a/std/recursion/sumcheck/claim_intf.go b/std/recursion/sumcheck/claim_intf.go index d2df83aea6..731234debd 100644 --- a/std/recursion/sumcheck/claim_intf.go +++ b/std/recursion/sumcheck/claim_intf.go @@ -28,12 +28,12 @@ type claims interface { NbVars() int // Combine combines separate claims into a single sumcheckable claim using // the coefficient coeff. - Combine(coeff *big.Int) nativePolynomial + Combine(coeff *big.Int) NativePolynomial // Next fixes the next free variable to r, keeps the next variable free and // sums over a hypercube for the last variables. Instead of returning the // polynomial in coefficient form, it returns the evaluations at degree // different points. - Next(r *big.Int) nativePolynomial + Next(r *big.Int) NativePolynomial // ProverFinalEval returns the (lazy) evaluation proof. - ProverFinalEval(r []*big.Int) nativeEvaluationProof + ProverFinalEval(r []*big.Int) NativeEvaluationProof } diff --git a/std/recursion/sumcheck/claimable_gate.go b/std/recursion/sumcheck/claimable_gate.go index 04884388ee..ad2a3d3a45 100644 --- a/std/recursion/sumcheck/claimable_gate.go +++ b/std/recursion/sumcheck/claimable_gate.go @@ -11,7 +11,7 @@ import ( ) // gate defines a multivariate polynomial which can be sumchecked. -type gate[AE arithEngine[E], E element] interface { +type gate[AE ArithEngine[E], E element] interface { // NbInputs is the number of inputs the gate takes. NbInputs() int // Evaluate evaluates the gate at inputs vars. @@ -27,9 +27,9 @@ type gate[AE arithEngine[E], E element] interface { type gateClaim[FR emulated.FieldParams] struct { f *emulated.Field[FR] p *polynomial.Polynomial[FR] - engine *emuEngine[FR] + engine *EmuEngine[FR] - gate gate[*emuEngine[FR], *emulated.Element[FR]] + gate gate[*EmuEngine[FR], *emulated.Element[FR]] evaluationPoints [][]*emulated.Element[FR] claimedEvaluations []*emulated.Element[FR] @@ -48,7 +48,7 @@ type gateClaim[FR emulated.FieldParams] struct { // evaluationPoints is the random coefficients for ensuring the consistency of // the inputs during the final round and claimedEvals is the claimed evaluation // values with the inputs combined at the evaluationPoints. -func newGate[FR emulated.FieldParams](api frontend.API, gate gate[*emuEngine[FR], *emulated.Element[FR]], +func newGate[FR emulated.FieldParams](api frontend.API, gate gate[*EmuEngine[FR], *emulated.Element[FR]], inputs [][]*emulated.Element[FR], evaluationPoints [][]*emulated.Element[FR], claimedEvals []*emulated.Element[FR]) (LazyClaims[FR], error) { nbInputs := gate.NbInputs() @@ -71,7 +71,7 @@ func newGate[FR emulated.FieldParams](api frontend.API, gate gate[*emuEngine[FR] if err != nil { return nil, fmt.Errorf("new polynomial: %w", err) } - engine, err := newEmulatedEngine[FR](api) + engine, err := NewEmulatedEngine[FR](api) if err != nil { return nil, fmt.Errorf("new emulated engine: %w", err) } @@ -152,9 +152,9 @@ func (g *gateClaim[FR]) AssertEvaluation(r []*emulated.Element[FR], combinationC } type nativeGateClaim struct { - engine *bigIntEngine + engine *BigIntEngine - gate gate[*bigIntEngine, *big.Int] + gate gate[*BigIntEngine, *big.Int] evaluationPoints [][]*big.Int claimedEvaluations []*big.Int @@ -163,13 +163,13 @@ type nativeGateClaim struct { // multi-instance input id to the instance value. This allows running // sumcheck over the hypercube. Every element in the slice represents the // input. - inputPreprocessors []nativeMultilinear + inputPreprocessors []NativeMultilinear - eq nativeMultilinear + eq NativeMultilinear } -func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [][]*big.Int, evaluationPoints [][]*big.Int) (claim claims, evaluations []*big.Int, err error) { - be := newBigIntEngine(target) +func newNativeGate(target *big.Int, gate gate[*BigIntEngine, *big.Int], inputs [][]*big.Int, evaluationPoints [][]*big.Int) (claim claims, evaluations []*big.Int, err error) { + be := &BigIntEngine{mod: new(big.Int).Set(target)} nbInputs := gate.NbInputs() if len(inputs) != nbInputs { return nil, nil, fmt.Errorf("expected %d inputs got %d", nbInputs, len(inputs)) @@ -184,7 +184,7 @@ func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [ evalInput := make([][]*big.Int, nbInstances) // TODO: pad input to power of two for i := range evalInput { - evalInput[i] = make(nativeMultilinear, nbInputs) + evalInput[i] = make(NativeMultilinear, nbInputs) for j := range evalInput[i] { evalInput[i][j] = new(big.Int).Set(inputs[j][i]) } @@ -196,9 +196,9 @@ func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [ evaluations[i] = gate.Evaluate(be, evalInput[i]...) } // construct the mapping (inputIdx, instanceIdx) -> inputVal - inputPreprocessors := make([]nativeMultilinear, nbInputs) + inputPreprocessors := make([]NativeMultilinear, nbInputs) for i := range inputs { - inputPreprocessors[i] = make(nativeMultilinear, nbInstances) + inputPreprocessors[i] = make(NativeMultilinear, nbInstances) for j := range inputs[i] { inputPreprocessors[i][j] = new(big.Int).Set(inputs[i][j]) } @@ -211,7 +211,7 @@ func newNativeGate(target *big.Int, gate gate[*bigIntEngine, *big.Int], inputs [ // compute the random linear combinations of the evaluation values of the gate claimedEvaluations := make([]*big.Int, len(evaluationPoints)) for i := range claimedEvaluations { - claimedEvaluations[i] = eval(be, evaluations, evaluationPoints[i]) + claimedEvaluations[i] = Eval(be, evaluations, evaluationPoints[i]) } return &nativeGateClaim{ engine: be, @@ -231,19 +231,19 @@ func (g *nativeGateClaim) NbVars() int { return len(g.evaluationPoints[0]) } -func (g *nativeGateClaim) Combine(coeff *big.Int) nativePolynomial { +func (g *nativeGateClaim) Combine(coeff *big.Int) NativePolynomial { nbVars := g.NbVars() eqLength := 1 << nbVars nbClaims := g.NbClaims() - g.eq = make(nativeMultilinear, eqLength) + g.eq = make(NativeMultilinear, eqLength) g.eq[0] = g.engine.One() for i := 1; i < eqLength; i++ { g.eq[i] = new(big.Int) } - g.eq = eq(g.engine, g.eq, g.evaluationPoints[0]) + g.eq = Eq(g.engine, g.eq, g.evaluationPoints[0]) - newEq := make(nativeMultilinear, eqLength) + newEq := make(NativeMultilinear, eqLength) for i := 1; i < eqLength; i++ { newEq[i] = new(big.Int) } @@ -251,7 +251,7 @@ func (g *nativeGateClaim) Combine(coeff *big.Int) nativePolynomial { for k := 1; k < nbClaims; k++ { newEq[0] = g.engine.One() - g.eq = eqAcc(g.engine, g.eq, newEq, g.evaluationPoints[k]) + g.eq = EqAcc(g.engine, g.eq, newEq, g.evaluationPoints[k]) if k+1 < nbClaims { aI = g.engine.Mul(aI, coeff) } @@ -260,32 +260,32 @@ func (g *nativeGateClaim) Combine(coeff *big.Int) nativePolynomial { return g.computeGJ() } -func (g *nativeGateClaim) Next(r *big.Int) nativePolynomial { +func (g *nativeGateClaim) Next(r *big.Int) NativePolynomial { for i := range g.inputPreprocessors { - g.inputPreprocessors[i] = fold(g.engine, g.inputPreprocessors[i], r) + g.inputPreprocessors[i] = Fold(g.engine, g.inputPreprocessors[i], r) } - g.eq = fold(g.engine, g.eq, r) + g.eq = Fold(g.engine, g.eq, r) return g.computeGJ() } -func (g *nativeGateClaim) ProverFinalEval(r []*big.Int) nativeEvaluationProof { +func (g *nativeGateClaim) ProverFinalEval(r []*big.Int) NativeEvaluationProof { // verifier computes the value of the gate (times the eq) itself return nil } -func (g *nativeGateClaim) computeGJ() nativePolynomial { +func (g *nativeGateClaim) computeGJ() NativePolynomial { // returns the polynomial GJ through its evaluations degGJ := 1 + g.gate.Degree() nbGateIn := len(g.inputPreprocessors) - s := make([]nativeMultilinear, nbGateIn+1) + s := make([]NativeMultilinear, nbGateIn+1) s[0] = g.eq copy(s[1:], g.inputPreprocessors) nbInner := len(s) nbOuter := len(s[0]) / 2 - gJ := make(nativePolynomial, degGJ) + gJ := make(NativePolynomial, degGJ) for i := range gJ { gJ[i] = new(big.Int) } diff --git a/std/recursion/sumcheck/claimable_multilinear.go b/std/recursion/sumcheck/claimable_multilinear.go index c73395514f..7bb4b43918 100644 --- a/std/recursion/sumcheck/claimable_multilinear.go +++ b/std/recursion/sumcheck/claimable_multilinear.go @@ -62,7 +62,7 @@ func (fn *multilinearClaim[FR]) AssertEvaluation(r []*emulated.Element[FR], comb } type nativeMultilinearClaim struct { - be *bigIntEngine + be *BigIntEngine ml []*big.Int } @@ -71,7 +71,7 @@ func newNativeMultilinearClaim(target *big.Int, ml []*big.Int) (claim claims, hy if bits.OnesCount(uint(len(ml))) != 1 { return nil, nil, fmt.Errorf("expecting power of two coeffs") } - be := newBigIntEngine(target) + be := NewBigIntEngine(target) hypersum = new(big.Int) for i := range ml { hypersum = be.Add(hypersum, ml[i]) @@ -91,16 +91,16 @@ func (fn *nativeMultilinearClaim) NbVars() int { return bits.Len(uint(len(fn.ml))) - 1 } -func (fn *nativeMultilinearClaim) Combine(coeff *big.Int) nativePolynomial { +func (fn *nativeMultilinearClaim) Combine(coeff *big.Int) NativePolynomial { return []*big.Int{hypersumX1One(fn.be, fn.ml)} } -func (fn *nativeMultilinearClaim) Next(r *big.Int) nativePolynomial { - fn.ml = fold(fn.be, fn.ml, r) +func (fn *nativeMultilinearClaim) Next(r *big.Int) NativePolynomial { + fn.ml = Fold(fn.be, fn.ml, r) return []*big.Int{hypersumX1One(fn.be, fn.ml)} } -func (fn *nativeMultilinearClaim) ProverFinalEval(r []*big.Int) nativeEvaluationProof { +func (fn *nativeMultilinearClaim) ProverFinalEval(r []*big.Int) NativeEvaluationProof { // verifier computes the value of the multilinear function itself return nil } diff --git a/std/recursion/sumcheck/fullscalarmul_test.go b/std/recursion/sumcheck/fullscalarmul_test.go new file mode 100644 index 0000000000..2bc4052f58 --- /dev/null +++ b/std/recursion/sumcheck/fullscalarmul_test.go @@ -0,0 +1,184 @@ +package sumcheck + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + fr_secp256k1 "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/test" +) + +type ScalarMulCircuit[Base, Scalars emulated.FieldParams] struct { + Points []sw_emulated.AffinePoint[Base] + Scalars []emulated.Element[Scalars] + + nbScalarBits int +} + +func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { + if len(c.Points) != len(c.Scalars) { + return fmt.Errorf("len(inputs) != len(scalars)") + } + baseApi, err := emulated.NewField[B](api) + if err != nil { + return fmt.Errorf("new base field: %w", err) + } + scalarApi, err := emulated.NewField[S](api) + if err != nil { + return fmt.Errorf("new scalar field: %w", err) + } + for i := range c.Points { + step, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i]) + if err != nil { + return fmt.Errorf("hint scalar mul steps: %w", err) + } + _ = step + } + return nil +} + +func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, + baseApi *emulated.Field[B], scalarApi *emulated.Field[S], + nbScalarBits int, + point sw_emulated.AffinePoint[B], scalar emulated.Element[S]) ([][6]*emulated.Element[B], error) { + var fp B + var fr S + inputs := []frontend.Variable{fp.BitsPerLimb(), fp.NbLimbs()} + inputs = append(inputs, baseApi.Modulus().Limbs...) + inputs = append(inputs, point.X.Limbs...) + inputs = append(inputs, point.Y.Limbs...) + inputs = append(inputs, fr.BitsPerLimb(), fr.NbLimbs()) + inputs = append(inputs, scalarApi.Modulus().Limbs...) + inputs = append(inputs, scalar.Limbs...) + nbRes := nbScalarBits * int(fp.NbLimbs()) * 6 + hintRes, err := api.Compiler().NewHint(hintScalarMulSteps, nbRes, inputs...) + if err != nil { + return nil, fmt.Errorf("new hint: %w", err) + } + res := make([][6]*emulated.Element[B], nbScalarBits) + for i := range res { + for j := 0; j < 6; j++ { + limbs := hintRes[i*(6*int(fp.NbLimbs()))+j*int(fp.NbLimbs()) : i*(6*int(fp.NbLimbs()))+(j+1)*int(fp.NbLimbs())] + res[i][j] = baseApi.NewElement(limbs) + } + } + return res, nil +} + +func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + nbBits := int(inputs[0].Int64()) + nbLimbs := int(inputs[1].Int64()) + fpLimbs := inputs[2 : 2+nbLimbs] + xLimbs := inputs[2+nbLimbs : 2+2*nbLimbs] + yLimbs := inputs[2+2*nbLimbs : 2+3*nbLimbs] + nbScalarBits := int(inputs[2+3*nbLimbs].Int64()) + nbScalarLimbs := int(inputs[3+3*nbLimbs].Int64()) + frLimbs := inputs[4+3*nbLimbs : 4+3*nbLimbs+nbScalarLimbs] + scalarLimbs := inputs[4+3*nbLimbs+nbScalarLimbs : 4+3*nbLimbs+2*nbScalarLimbs] + + x := new(big.Int) + y := new(big.Int) + fp := new(big.Int) + fr := new(big.Int) + scalar := new(big.Int) + if err := recompose(fpLimbs, uint(nbBits), fp); err != nil { + return fmt.Errorf("recompose fp: %w", err) + } + if err := recompose(frLimbs, uint(nbScalarBits), fr); err != nil { + return fmt.Errorf("recompose fr: %w", err) + } + if err := recompose(xLimbs, uint(nbBits), x); err != nil { + return fmt.Errorf("recompose x: %w", err) + } + if err := recompose(yLimbs, uint(nbBits), y); err != nil { + return fmt.Errorf("recompose y: %w", err) + } + if err := recompose(scalarLimbs, uint(nbScalarBits), scalar); err != nil { + return fmt.Errorf("recompose scalar: %w", err) + } + fmt.Println(fp, fr, x, y, scalar) + + scalarLength := len(outputs) / (6 * nbLimbs) + println("scalarLength", scalarLength) + return nil +} + +func recompose(inputs []*big.Int, nbBits uint, res *big.Int) error { + if len(inputs) == 0 { + return fmt.Errorf("zero length slice input") + } + if res == nil { + return fmt.Errorf("result not initialized") + } + res.SetUint64(0) + for i := range inputs { + res.Lsh(res, nbBits) + res.Add(res, inputs[len(inputs)-i-1]) + } + // TODO @gbotrel mod reduce ? + return nil +} + +func decompose(input *big.Int, nbBits uint, res []*big.Int) error { + // limb modulus + if input.BitLen() > len(res)*int(nbBits) { + return fmt.Errorf("decomposed integer does not fit into res") + } + for _, r := range res { + if r == nil { + return fmt.Errorf("result slice element uninitalized") + } + } + base := new(big.Int).Lsh(big.NewInt(1), nbBits) + tmp := new(big.Int).Set(input) + for i := 0; i < len(res); i++ { + res[i].Mod(tmp, base) + tmp.Rsh(tmp, nbBits) + } + return nil +} + +func TestScalarMul(t *testing.T) { + assert := test.NewAssert(t) + type B = emparams.Secp256k1Fp + type S = emparams.Secp256k1Fr + t.Log(B{}.Modulus(), S{}.Modulus()) + var P secp256k1.G1Affine + var s fr_secp256k1.Element + nbInputs := 1 << 0 + nbScalarBits := 2 + scalarBound := new(big.Int).Lsh(big.NewInt(1), uint(nbScalarBits)) + points := make([]sw_emulated.AffinePoint[B], nbInputs) + scalars := make([]emulated.Element[S], nbInputs) + for i := range points { + s.SetRandom() + P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) + sc, _ := rand.Int(rand.Reader, scalarBound) + t.Log(P.X.String(), P.Y.String(), sc.String()) + points[i] = sw_emulated.AffinePoint[B]{ + X: emulated.ValueOf[B](P.X), + Y: emulated.ValueOf[B](P.Y), + } + scalars[i] = emulated.ValueOf[S](sc) + } + circuit := ScalarMulCircuit[B, S]{ + Points: make([]sw_emulated.AffinePoint[B], nbInputs), + Scalars: make([]emulated.Element[S], nbInputs), + nbScalarBits: nbScalarBits, + } + witness := ScalarMulCircuit[B, S]{ + Points: points, + Scalars: scalars, + } + err := test.IsSolved(&circuit, &witness, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} \ No newline at end of file diff --git a/std/recursion/sumcheck/polynomial.go b/std/recursion/sumcheck/polynomial.go index aaeb318fe4..3e0da31a38 100644 --- a/std/recursion/sumcheck/polynomial.go +++ b/std/recursion/sumcheck/polynomial.go @@ -2,14 +2,48 @@ package sumcheck import ( "math/big" + "math/bits" ) -type nativePolynomial []*big.Int -type nativeMultilinear []*big.Int +type NativePolynomial []*big.Int +type NativeMultilinear []*big.Int // helper functions for multilinear polynomial evaluations -func fold(api *bigIntEngine, ml nativeMultilinear, r *big.Int) nativeMultilinear { +// Clone returns a deep copy of p. +// If capacity is provided, the new coefficient slice capacity will be set accordingly. +func (p NativeMultilinear) Clone(capacity ...int) NativeMultilinear { + var newCapacity int + if len(capacity) > 0 { + newCapacity = capacity[0] + } else { + newCapacity = len(p) + } + + res := make(NativeMultilinear, len(p), newCapacity) + for i, v := range p { + res[i] = new(big.Int).Set(v) + } + return res +} + +func DereferenceBigIntSlice(ptrs []*big.Int) []big.Int { + vals := make([]big.Int, len(ptrs)) + for i, ptr := range ptrs { + vals[i] = *ptr + } + return vals +} + +func ReferenceBigIntSlice(vals []big.Int) []*big.Int { + ptrs := make([]*big.Int, len(vals)) + for i := range ptrs { + ptrs[i] = &vals[i] + } + return ptrs +} + +func Fold(api *BigIntEngine, ml NativeMultilinear, r *big.Int) NativeMultilinear { // NB! it modifies ml in-place and also returns mid := len(ml) / 2 bottom, top := ml[:mid], ml[mid:] @@ -22,7 +56,7 @@ func fold(api *bigIntEngine, ml nativeMultilinear, r *big.Int) nativeMultilinear return ml[:mid] } -func hypersumX1One(api *bigIntEngine, ml nativeMultilinear) *big.Int { +func hypersumX1One(api *BigIntEngine, ml NativeMultilinear) *big.Int { sum := ml[len(ml)/2] for i := len(ml)/2 + 1; i < len(ml); i++ { sum = api.Add(sum, ml[i]) @@ -30,7 +64,7 @@ func hypersumX1One(api *bigIntEngine, ml nativeMultilinear) *big.Int { return sum } -func eq(api *bigIntEngine, ml nativeMultilinear, q []*big.Int) nativeMultilinear { +func Eq(api *BigIntEngine, ml NativeMultilinear, q []*big.Int) NativeMultilinear { if (1 << len(q)) != len(ml) { panic("scalar length mismatch") } @@ -46,20 +80,20 @@ func eq(api *bigIntEngine, ml nativeMultilinear, q []*big.Int) nativeMultilinear return ml } -func eval(api *bigIntEngine, ml nativeMultilinear, r []*big.Int) *big.Int { - mlCopy := make(nativeMultilinear, len(ml)) +func Eval(api *BigIntEngine, ml NativeMultilinear, r []*big.Int) *big.Int { + mlCopy := make(NativeMultilinear, len(ml)) for i := range mlCopy { mlCopy[i] = new(big.Int).Set(ml[i]) } for _, ri := range r { - mlCopy = fold(api, mlCopy, ri) + mlCopy = Fold(api, mlCopy, ri) } return mlCopy[0] } -func eqAcc(api *bigIntEngine, e nativeMultilinear, m nativeMultilinear, q []*big.Int) nativeMultilinear { +func EqAcc(api *BigIntEngine, e NativeMultilinear, m NativeMultilinear, q []*big.Int) NativeMultilinear { if len(e) != len(m) { panic("length mismatch") } @@ -83,3 +117,7 @@ func eqAcc(api *bigIntEngine, e nativeMultilinear, m nativeMultilinear, q []*big } return e } + +func (m NativeMultilinear) NumVars() int { + return bits.TrailingZeros(uint(len(m))) +} diff --git a/std/recursion/sumcheck/proof.go b/std/recursion/sumcheck/proof.go index cdba88cc7d..67e28fb7ef 100644 --- a/std/recursion/sumcheck/proof.go +++ b/std/recursion/sumcheck/proof.go @@ -1,6 +1,8 @@ package sumcheck import ( + "math/big" + "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/polynomial" ) @@ -14,9 +16,9 @@ type Proof[FR emulated.FieldParams] struct { FinalEvalProof EvaluationProof } -type nativeProof struct { - RoundPolyEvaluations []nativePolynomial - FinalEvalProof nativeEvaluationProof +type NativeProof struct { + RoundPolyEvaluations []NativePolynomial + FinalEvalProof NativeEvaluationProof } // EvaluationProof is proof for allowing the sumcheck verifier to perform the @@ -27,16 +29,32 @@ type nativeProof struct { // - if it is deferred, then it is a slice. type EvaluationProof any -type nativeEvaluationProof any +// evaluationProof for gkr +type DeferredEvalProof[FR emulated.FieldParams] []emulated.Element[FR] +type NativeDeferredEvalProof []big.Int + +type NativeEvaluationProof any -func valueOfProof[FR emulated.FieldParams](nproof nativeProof) Proof[FR] { +func ValueOfProof[FR emulated.FieldParams](nproof NativeProof) Proof[FR] { rps := make([]polynomial.Univariate[FR], len(nproof.RoundPolyEvaluations)) + finaleval := nproof.FinalEvalProof + if finaleval != nil { + switch v := finaleval.(type) { + case NativeDeferredEvalProof: + deferredEval := make(DeferredEvalProof[FR], len(v)) + for i := range v { + deferredEval[i] = emulated.ValueOf[FR](v[i]) + } + finaleval = deferredEval + } + } for i := range nproof.RoundPolyEvaluations { rps[i] = polynomial.ValueOfUnivariate[FR](nproof.RoundPolyEvaluations[i]) } - // TODO: type switch FinalEvalProof when it is not-nil + return Proof[FR]{ RoundPolyEvaluations: rps, + FinalEvalProof: finaleval, } } diff --git a/std/recursion/sumcheck/prover.go b/std/recursion/sumcheck/prover.go index c075cf1530..d79c467db8 100644 --- a/std/recursion/sumcheck/prover.go +++ b/std/recursion/sumcheck/prover.go @@ -32,8 +32,8 @@ func newProverConfig(opts ...proverOption) (*proverConfig, error) { return ret, nil } -func prove(current *big.Int, target *big.Int, claims claims, opts ...proverOption) (nativeProof, error) { - var proof nativeProof +func Prove(current *big.Int, target *big.Int, claims claims, opts ...proverOption) (NativeProof, error) { + var proof NativeProof cfg, err := newProverConfig(opts...) if err != nil { return proof, fmt.Errorf("parse options: %w", err) @@ -44,46 +44,43 @@ func prove(current *big.Int, target *big.Int, claims claims, opts ...proverOptio return proof, fmt.Errorf("new short hash: %w", err) } fs := fiatshamir.NewTranscript(fshash, challengeNames...) - if err != nil { - return proof, fmt.Errorf("new transcript: %w", err) - } // bind challenge from previous round if it is a continuation - if err = bindChallengeProver(fs, challengeNames[0], cfg.baseChallenges); err != nil { + if err = BindChallengeProver(fs, challengeNames[0], cfg.baseChallenges); err != nil { return proof, fmt.Errorf("base: %w", err) } combinationCoef := big.NewInt(0) if claims.NbClaims() >= 2 { - if combinationCoef, challengeNames, err = deriveChallengeProver(fs, challengeNames, nil); err != nil { + if combinationCoef, challengeNames, err = DeriveChallengeProver(fs, challengeNames, nil); err != nil { return proof, fmt.Errorf("derive combination coef: %w", err) } } // in sumcheck we run a round for every variable. So the number of variables // defines the number of rounds. nbVars := claims.NbVars() - proof.RoundPolyEvaluations = make([]nativePolynomial, nbVars) + proof.RoundPolyEvaluations = make([]NativePolynomial, nbVars) // the first round in the sumcheck is without verifier challenge. Combine challenges and provers sends the first polynomial proof.RoundPolyEvaluations[0] = claims.Combine(combinationCoef) - challenges := make([]*big.Int, nbVars) // we iterate over all variables. However, we omit the last round as the // final evaluation is possibly deferred. for j := 0; j < nbVars-1; j++ { // compute challenge for the next round - if challenges[j], challengeNames, err = deriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[j]); err != nil { + if challenges[j], challengeNames, err = DeriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[j]); err != nil { return proof, fmt.Errorf("derive challenge: %w", err) } // compute the univariate polynomial with first j variables fixed. proof.RoundPolyEvaluations[j+1] = claims.Next(challenges[j]) } - if challenges[nbVars-1], challengeNames, err = deriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[nbVars-1]); err != nil { + if challenges[nbVars-1], challengeNames, err = DeriveChallengeProver(fs, challengeNames, proof.RoundPolyEvaluations[nbVars-1]); err != nil { return proof, fmt.Errorf("derive challenge: %w", err) } if len(challengeNames) > 0 { return proof, fmt.Errorf("excessive challenges") } + proof.FinalEvalProof = claims.ProverFinalEval(challenges) return proof, nil diff --git a/std/recursion/sumcheck/scalarmul_gates_test.go b/std/recursion/sumcheck/scalarmul_gates.go similarity index 88% rename from std/recursion/sumcheck/scalarmul_gates_test.go rename to std/recursion/sumcheck/scalarmul_gates.go index 30ff77e1ad..71ef0207e8 100644 --- a/std/recursion/sumcheck/scalarmul_gates_test.go +++ b/std/recursion/sumcheck/scalarmul_gates.go @@ -14,13 +14,13 @@ import ( "github.com/consensys/gnark/test" ) -type projAddGate[AE arithEngine[E], E element] struct { - folding E +type ProjAddGate[AE ArithEngine[E], E element] struct { + Folding E } -func (m projAddGate[AE, E]) NbInputs() int { return 6 } -func (m projAddGate[AE, E]) Degree() int { return 4 } -func (m projAddGate[AE, E]) Evaluate(api AE, vars ...E) E { +func (m ProjAddGate[AE, E]) NbInputs() int { return 6 } +func (m ProjAddGate[AE, E]) Degree() int { return 4 } +func (m ProjAddGate[AE, E]) Evaluate(api AE, vars ...E) E { if len(vars) != m.NbInputs() { panic("incorrect nb of inputs") } @@ -61,9 +61,9 @@ func (m projAddGate[AE, E]) Evaluate(api AE, vars ...E) E { Z3 = api.Mul(Z3, t4) Z3 = api.Add(Z3, t0) - res := api.Mul(m.folding, Z3) + res := api.Mul(m.Folding, Z3) res = api.Add(res, Y3) - res = api.Mul(m.folding, res) + res = api.Mul(m.Folding, res) res = api.Add(res, X3) return res } @@ -102,7 +102,7 @@ func (c *ProjAddSumcheckCircuit[FR]) Define(api frontend.API) error { for i := range c.EvaluationPoints { evalPoints[i] = polynomial.FromSlice[FR](c.EvaluationPoints[i]) } - claim, err := newGate[FR](api, projAddGate[*emuEngine[FR], *emulated.Element[FR]]{f.NewElement(123)}, inputs, evalPoints, claimedEvals) + claim, err := newGate[FR](api, ProjAddGate[*EmuEngine[FR], *emulated.Element[FR]]{f.NewElement(123)}, inputs, evalPoints, claimedEvals) if err != nil { return fmt.Errorf("new gate claim: %w", err) } @@ -114,7 +114,7 @@ func (c *ProjAddSumcheckCircuit[FR]) Define(api frontend.API) error { func testProjAddSumCheckInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, inputs [][]int) { var fr FR - nativeGate := projAddGate[*bigIntEngine, *big.Int]{folding: big.NewInt(123)} + nativeGate := ProjAddGate[*BigIntEngine, *big.Int]{Folding: big.NewInt(123)} assert := test.NewAssert(t) inputB := make([][]*big.Int, len(inputs)) for i := range inputB { @@ -126,7 +126,7 @@ func testProjAddSumCheckInstance[FR emulated.FieldParams](t *testing.T, current evalPointsB, evalPointsPH, evalPointsC := getChallengeEvaluationPoints[FR](inputB) claim, evals, err := newNativeGate(fr.Modulus(), nativeGate, inputB, evalPointsB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(inputs[0]))) - 1 circuit := &ProjAddSumcheckCircuit[FR]{ @@ -137,7 +137,7 @@ func testProjAddSumCheckInstance[FR emulated.FieldParams](t *testing.T, current } assignment := &ProjAddSumcheckCircuit[FR]{ Inputs: make([][]emulated.Element[FR], len(inputs)), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), EvaluationPoints: evalPointsC, Claimed: []emulated.Element[FR]{emulated.ValueOf[FR](evals[0])}, } @@ -168,11 +168,11 @@ func TestProjAddSumCheckSumcheck(t *testing.T) { testProjAddSumCheckInstance[emparams.BN254Fr](t, ecc.BN254.ScalarField(), inputs) } -type dblAddSelectGate[AE arithEngine[E], E element] struct { - folding []E +type DblAddSelectGate[AE ArithEngine[E], E element] struct { + Folding []E } -func projAdd[AE arithEngine[E], E element](api AE, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { +func projAdd[AE ArithEngine[E], E element](api AE, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { b3 := api.Const(big.NewInt(21)) t0 := api.Mul(X1, X2) t1 := api.Mul(Y1, Y2) @@ -210,7 +210,7 @@ func projAdd[AE arithEngine[E], E element](api AE, X1, Y1, Z1, X2, Y2, Z2 E) (X3 return } -func projSelect[AE arithEngine[E], E element](api AE, selector, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { +func projSelect[AE ArithEngine[E], E element](api AE, selector, X1, Y1, Z1, X2, Y2, Z2 E) (X3, Y3, Z3 E) { X3 = api.Sub(X1, X2) X3 = api.Mul(selector, X3) X3 = api.Add(X3, X2) @@ -225,7 +225,7 @@ func projSelect[AE arithEngine[E], E element](api AE, selector, X1, Y1, Z1, X2, return } -func projDbl[AE arithEngine[E], E element](api AE, X, Y, Z E) (X3, Y3, Z3 E) { +func projDbl[AE ArithEngine[E], E element](api AE, X, Y, Z E) (X3, Y3, Z3 E) { b3 := api.Const(big.NewInt(21)) t0 := api.Mul(Y, Y) Z3 = api.Add(t0, t0) @@ -248,14 +248,14 @@ func projDbl[AE arithEngine[E], E element](api AE, X, Y, Z E) (X3, Y3, Z3 E) { return } -func (m dblAddSelectGate[AE, E]) NbInputs() int { return 7 } -func (m dblAddSelectGate[AE, E]) Degree() int { return 5 } -func (m dblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) E { +func (m DblAddSelectGate[AE, E]) NbInputs() int { return 7 } +func (m DblAddSelectGate[AE, E]) Degree() int { return 5 } +func (m DblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) E { if len(vars) != m.NbInputs() { panic("incorrect nb of inputs") } - if len(m.folding) != m.NbInputs()-1 { - panic("incorrect nb of folding vars") + if len(m.Folding) != m.NbInputs()-1 { + panic("incorrect nb of Folding vars") } // X1, Y1, Z1 == accumulator X1, Y1, Z1 := vars[0], vars[1], vars[2] @@ -267,13 +267,13 @@ func (m dblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) E { ResX, ResY, ResZ := projSelect(api, selector, tmpX, tmpY, tmpZ, X2, Y2, Z2) AccX, AccY, AccZ := projDbl(api, X1, Y1, Z1) - // folding part - f0 := api.Mul(m.folding[0], AccX) - f1 := api.Mul(m.folding[1], AccY) - f2 := api.Mul(m.folding[2], AccZ) - f3 := api.Mul(m.folding[3], ResX) - f4 := api.Mul(m.folding[4], ResY) - f5 := api.Mul(m.folding[5], ResZ) + // Folding part + f0 := api.Mul(m.Folding[0], AccX) + f1 := api.Mul(m.Folding[1], AccY) + f2 := api.Mul(m.Folding[2], AccZ) + f3 := api.Mul(m.Folding[3], ResX) + f4 := api.Mul(m.Folding[4], ResY) + f5 := api.Mul(m.Folding[5], ResZ) res := api.Add(f0, f1) res = api.Add(res, f2) res = api.Add(res, f3) @@ -285,7 +285,7 @@ func (m dblAddSelectGate[AE, E]) Evaluate(api AE, vars ...E) E { func TestDblAndAddGate(t *testing.T) { assert := test.NewAssert(t) - nativeGate := dblAddSelectGate[*bigIntEngine, *big.Int]{folding: []*big.Int{ + nativeGate := DblAddSelectGate[*BigIntEngine, *big.Int]{Folding: []*big.Int{ big.NewInt(1), big.NewInt(2), big.NewInt(3), @@ -299,7 +299,7 @@ func TestDblAndAddGate(t *testing.T) { assert.True(ok) secpfp, ok := new(big.Int).SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16) assert.True(ok) - eng := newBigIntEngine(secpfp) + eng := NewBigIntEngine(secpfp) res := nativeGate.Evaluate(eng, px, py, big.NewInt(1), big.NewInt(0), big.NewInt(1), big.NewInt(0), big.NewInt(1)) t.Log(res) _ = res @@ -339,9 +339,9 @@ func (c *ProjDblAddSelectSumcheckCircuit[FR]) Define(api frontend.API) error { for i := range c.EvaluationPoints { evalPoints[i] = polynomial.FromSlice[FR](c.EvaluationPoints[i]) } - claim, err := newGate[FR](api, dblAddSelectGate[*emuEngine[FR], + claim, err := newGate[FR](api, DblAddSelectGate[*EmuEngine[FR], *emulated.Element[FR]]{ - folding: []*emulated.Element[FR]{ + Folding: []*emulated.Element[FR]{ f.NewElement(1), f.NewElement(2), f.NewElement(3), @@ -361,7 +361,7 @@ func (c *ProjDblAddSelectSumcheckCircuit[FR]) Define(api frontend.API) error { func testProjDblAddSelectSumCheckInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, inputs [][]int) { var fr FR - nativeGate := dblAddSelectGate[*bigIntEngine, *big.Int]{folding: []*big.Int{ + nativeGate := DblAddSelectGate[*BigIntEngine, *big.Int]{Folding: []*big.Int{ big.NewInt(1), big.NewInt(2), big.NewInt(3), @@ -380,7 +380,7 @@ func testProjDblAddSelectSumCheckInstance[FR emulated.FieldParams](t *testing.T, evalPointsB, evalPointsPH, evalPointsC := getChallengeEvaluationPoints[FR](inputB) claim, evals, err := newNativeGate(fr.Modulus(), nativeGate, inputB, evalPointsB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(inputs[0]))) - 1 circuit := &ProjDblAddSelectSumcheckCircuit[FR]{ @@ -391,7 +391,7 @@ func testProjDblAddSelectSumCheckInstance[FR emulated.FieldParams](t *testing.T, } assignment := &ProjDblAddSelectSumcheckCircuit[FR]{ Inputs: make([][]emulated.Element[FR], len(inputs)), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), EvaluationPoints: evalPointsC, Claimed: []emulated.Element[FR]{emulated.ValueOf[FR](evals[0])}, } diff --git a/std/recursion/sumcheck/sumcheck_test.go b/std/recursion/sumcheck/sumcheck.go similarity index 94% rename from std/recursion/sumcheck/sumcheck_test.go rename to std/recursion/sumcheck/sumcheck.go index 1127e46e88..0a19dc8e21 100644 --- a/std/recursion/sumcheck/sumcheck_test.go +++ b/std/recursion/sumcheck/sumcheck.go @@ -46,7 +46,7 @@ func testMultilinearSumcheckInstance[FR emulated.FieldParams](t *testing.T, curr claim, value, err := newNativeMultilinearClaim(fr.Modulus(), mleB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(mle))) - 1 circuit := &MultilinearSumcheckCircuit[FR]{ @@ -56,7 +56,7 @@ func testMultilinearSumcheckInstance[FR emulated.FieldParams](t *testing.T, curr assignment := &MultilinearSumcheckCircuit[FR]{ Function: polynomial.ValueOfMultilinear[FR](mleB), Claim: emulated.ValueOf[FR](value), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), } err = test.IsSolved(circuit, assignment, current) assert.NoError(err) @@ -92,7 +92,7 @@ func getChallengeEvaluationPoints[FR emulated.FieldParams](inputs [][]*big.Int) return } -type mulGate1[AE arithEngine[E], E element] struct{} +type mulGate1[AE ArithEngine[E], E element] struct{} func (m mulGate1[AE, E]) NbInputs() int { return 2 } func (m mulGate1[AE, E]) Degree() int { return 2 } @@ -133,7 +133,7 @@ func (c *MulGateSumcheck[FR]) Define(api frontend.API) error { for i := range c.EvaluationPoints { evalPoints[i] = polynomial.FromSlice[FR](c.EvaluationPoints[i]) } - claim, err := newGate[FR](api, mulGate1[*emuEngine[FR], *emulated.Element[FR]]{}, inputs, evalPoints, claimedEvals) + claim, err := newGate[FR](api, mulGate1[*EmuEngine[FR], *emulated.Element[FR]]{}, inputs, evalPoints, claimedEvals) if err != nil { return fmt.Errorf("new gate claim: %w", err) } @@ -145,7 +145,7 @@ func (c *MulGateSumcheck[FR]) Define(api frontend.API) error { func testMulGate1SumcheckInstance[FR emulated.FieldParams](t *testing.T, current *big.Int, inputs [][]int) { var fr FR - var nativeGate mulGate1[*bigIntEngine, *big.Int] + var nativeGate mulGate1[*BigIntEngine, *big.Int] assert := test.NewAssert(t) inputB := make([][]*big.Int, len(inputs)) for i := range inputB { @@ -157,7 +157,7 @@ func testMulGate1SumcheckInstance[FR emulated.FieldParams](t *testing.T, current evalPointsB, evalPointsPH, evalPointsC := getChallengeEvaluationPoints[FR](inputB) claim, evals, err := newNativeGate(fr.Modulus(), nativeGate, inputB, evalPointsB) assert.NoError(err) - proof, err := prove(current, fr.Modulus(), claim) + proof, err := Prove(current, fr.Modulus(), claim) assert.NoError(err) nbVars := bits.Len(uint(len(inputs[0]))) - 1 circuit := &MulGateSumcheck[FR]{ @@ -168,7 +168,7 @@ func testMulGate1SumcheckInstance[FR emulated.FieldParams](t *testing.T, current } assignment := &MulGateSumcheck[FR]{ Inputs: make([][]emulated.Element[FR], len(inputs)), - Proof: valueOfProof[FR](proof), + Proof: ValueOfProof[FR](proof), EvaluationPoints: evalPointsC, Claimed: []emulated.Element[FR]{emulated.ValueOf[FR](evals[0])}, } diff --git a/std/recursion/sumcheck/test_vectors/mimc_five_levels_two_instances._json b/std/recursion/sumcheck/test_vectors/mimc_five_levels_two_instances._json new file mode 100644 index 0000000000..446d23fdb2 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/mimc_five_levels_two_instances._json @@ -0,0 +1,7 @@ +{ + "hash": {"type": "const", "val": -1}, + "circuit": "resources/mimc_five_levels.json", + "input": [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]], + "output": [[4, 3]], + "proof": [[{"partialSumPolys":[[3,4]],"finalEvalProof":[3]}],[{"partialSumPolys":null,"finalEvalProof":null}]] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/mimc_five_levels.json b/std/recursion/sumcheck/test_vectors/resources/mimc_five_levels.json new file mode 100644 index 0000000000..3dd74f42b5 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/mimc_five_levels.json @@ -0,0 +1,36 @@ +[ + [ + { + "gate": "mimc", + "inputs": [[1,0], [5,5]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[2,0], [5,4]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[3,0], [5,3]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[4,0], [5,2]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[5,0], [5,1]] + } + ], + [ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} + ] +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_identity_gate.json b/std/recursion/sumcheck/test_vectors/resources/single_identity_gate.json new file mode 100644 index 0000000000..a44066c7b4 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_identity_gate.json @@ -0,0 +1,10 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_input_two_identity_gates.json b/std/recursion/sumcheck/test_vectors/resources/single_input_two_identity_gates.json new file mode 100644 index 0000000000..6181784fa8 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_input_two_identity_gates.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_input_two_outs.json b/std/recursion/sumcheck/test_vectors/resources/single_input_two_outs.json new file mode 100644 index 0000000000..c577c1cace --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_input_two_outs.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_mimc_gate.json b/std/recursion/sumcheck/test_vectors/resources/single_mimc_gate.json new file mode 100644 index 0000000000..c89e7d52ae --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_mimc_gate.json @@ -0,0 +1,7 @@ +[ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + { + "gate": "mimc", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/single_mul_gate.json b/std/recursion/sumcheck/test_vectors/resources/single_mul_gate.json new file mode 100644 index 0000000000..0f65a07edf --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/single_mul_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/two_identity_gates_composed_single_input.json b/std/recursion/sumcheck/test_vectors/resources/two_identity_gates_composed_single_input.json new file mode 100644 index 0000000000..26681c2f89 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/two_identity_gates_composed_single_input.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/resources/two_inputs_select-input-3_gate.json b/std/recursion/sumcheck/test_vectors/resources/two_inputs_select-input-3_gate.json new file mode 100644 index 0000000000..cdbdb3b471 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/resources/two_inputs_select-input-3_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "select-input-3", + "inputs": [0,0,1] + } +] \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_identity_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/single_identity_gate_two_instances.json new file mode 100644 index 0000000000..ce326d0a63 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_identity_gate_two_instances.json @@ -0,0 +1,36 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_identity_gate.json", + "input": [ + [ + 4, + 3 + ] + ], + "output": [ + [ + 4, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5 + ], + "partialSumPolys": [ + [ + -3, + -8 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_input_two_identity_gates_two_instances.json b/std/recursion/sumcheck/test_vectors/single_input_two_identity_gates_two_instances.json new file mode 100644 index 0000000000..2c95f044f2 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_input_two_identity_gates_two_instances.json @@ -0,0 +1,56 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_identity_gates.json", + "input": [ + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ], + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_input_two_outs_two_instances.json b/std/recursion/sumcheck/test_vectors/single_input_two_outs_two_instances.json new file mode 100644 index 0000000000..d348303d0e --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_input_two_outs_two_instances.json @@ -0,0 +1,57 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_outs.json", + "input": [ + [ + 1, + 2 + ] + ], + "output": [ + [ + 1, + 4 + ], + [ + 1, + 2 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -4, + -36, + -112 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -2, + -12 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_mimc_gate_four_instances.json b/std/recursion/sumcheck/test_vectors/single_mimc_gate_four_instances.json new file mode 100644 index 0000000000..525459ecb1 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_mimc_gate_four_instances.json @@ -0,0 +1,67 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1, + 2, + 1 + ], + [ + 1, + 2, + 2, + 1 + ] + ], + "output": [ + [ + 128, + 2187, + 16384, + 128 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + -3 + ], + "partialSumPolys": [ + [ + -32640, + -2239484, + -29360128, + "-200000010", + "-931628672", + "-3373267120", + "-10200858624", + "-26939400158" + ], + [ + -81920, + -41943040, + "-1254113280", + "-13421772800", + "-83200000000", + "-366917713920", + "-1281828208640", + "-3779571220480" + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_mimc_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/single_mimc_gate_two_instances.json new file mode 100644 index 0000000000..7fa23ce4b1 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_mimc_gate_two_instances.json @@ -0,0 +1,51 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ] + ], + "output": [ + [ + 128, + 2187 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 1, + 0 + ], + "partialSumPolys": [ + [ + -2187, + -65536, + -546875, + -2799360, + -10706059, + -33554432, + -90876411, + "-220000000" + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/single_mul_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/single_mul_gate_two_instances.json new file mode 100644 index 0000000000..75c1d59c3d --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/single_mul_gate_two_instances.json @@ -0,0 +1,46 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mul_gate.json", + "input": [ + [ + 4, + 3 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 8, + 9 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5, + 1 + ], + "partialSumPolys": [ + [ + -9, + -32, + -35 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/std/recursion/sumcheck/test_vectors/two_identity_gates_composed_single_input_two_instances.json new file mode 100644 index 0000000000..10e5f1ff3c --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -0,0 +1,47 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_identity_gates_composed_single_input.json", + "input": [ + [ + 2, + 1 + ] + ], + "output": [ + [ + 2, + 1 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/std/recursion/sumcheck/test_vectors/two_inputs_select-input-3_gate_two_instances.json new file mode 100644 index 0000000000..19e127df71 --- /dev/null +++ b/std/recursion/sumcheck/test_vectors/two_inputs_select-input-3_gate_two_instances.json @@ -0,0 +1,45 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_inputs_select-input-3_gate.json", + "input": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/std/recursion/sumcheck/verifier.go b/std/recursion/sumcheck/verifier.go index 6674453ea8..4224d8a56d 100644 --- a/std/recursion/sumcheck/verifier.go +++ b/std/recursion/sumcheck/verifier.go @@ -9,25 +9,25 @@ import ( "github.com/consensys/gnark/std/recursion" ) -type config struct { +type Config struct { prefix string } // Option allows to alter the sumcheck verifier behaviour. -type Option func(c *config) error +type Option func(c *Config) error // WithClaimPrefix prepends the given string to the challenge names when // computing the challenges inside the sumcheck verifier. The option is used in // a higher level protocols to ensure that sumcheck claims are not interchanged. func WithClaimPrefix(prefix string) Option { - return func(c *config) error { + return func(c *Config) error { c.prefix = prefix return nil } } -func newConfig(opts ...Option) (*config, error) { - cfg := new(config) +func NewConfig(opts ...Option) (*Config, error) { + cfg := new(Config) for i := range opts { if err := opts[i](cfg); err != nil { return nil, fmt.Errorf("apply option %d: %w", i, err) @@ -37,7 +37,7 @@ func newConfig(opts ...Option) (*config, error) { } type verifyCfg[FR emulated.FieldParams] struct { - baseChallenges []emulated.Element[FR] + BaseChallenges []emulated.Element[FR] } // VerifyOption allows to alter the behaviour of the single sumcheck proof verification. @@ -48,13 +48,13 @@ type VerifyOption[FR emulated.FieldParams] func(c *verifyCfg[FR]) error func WithBaseChallenges[FR emulated.FieldParams](baseChallenges []*emulated.Element[FR]) VerifyOption[FR] { return func(c *verifyCfg[FR]) error { for i := range baseChallenges { - c.baseChallenges = append(c.baseChallenges, *baseChallenges[i]) + c.BaseChallenges = append(c.BaseChallenges, *baseChallenges[i]) } return nil } } -func newVerificationConfig[FR emulated.FieldParams](opts ...VerifyOption[FR]) (*verifyCfg[FR], error) { +func NewVerificationConfig[FR emulated.FieldParams](opts ...VerifyOption[FR]) (*verifyCfg[FR], error) { cfg := new(verifyCfg[FR]) for i := range opts { if err := opts[i](cfg); err != nil { @@ -69,14 +69,14 @@ type Verifier[FR emulated.FieldParams] struct { api frontend.API f *emulated.Field[FR] p *polynomial.Polynomial[FR] - *config + *Config } // NewVerifier initializes a new sumcheck verifier for the parametric emulated // field FR. It returns an error if the given options are invalid or when // initializing emulated arithmetic fails. func NewVerifier[FR emulated.FieldParams](api frontend.API, opts ...Option) (*Verifier[FR], error) { - cfg, err := newConfig(opts...) + cfg, err := NewConfig(opts...) if err != nil { return nil, fmt.Errorf("new configuration: %w", err) } @@ -92,14 +92,14 @@ func NewVerifier[FR emulated.FieldParams](api frontend.API, opts ...Option) (*Ve api: api, f: f, p: p, - config: cfg, + Config: cfg, }, nil } // Verify verifies the sumcheck proof for the given (lazy) claims. func (v *Verifier[FR]) Verify(claims LazyClaims[FR], proof Proof[FR], opts ...VerifyOption[FR]) error { var fr FR - cfg, err := newVerificationConfig(opts...) + cfg, err := NewVerificationConfig(opts...) if err != nil { return fmt.Errorf("verification opts: %w", err) } @@ -109,7 +109,7 @@ func (v *Verifier[FR]) Verify(claims LazyClaims[FR], proof Proof[FR], opts ...Ve return fmt.Errorf("new transcript: %w", err) } // bind challenge from previous round if it is a continuation - if err = v.bindChallenge(fs, challengeNames[0], cfg.baseChallenges); err != nil { + if err = v.bindChallenge(fs, challengeNames[0], cfg.BaseChallenges); err != nil { return fmt.Errorf("base: %w", err) } diff --git a/std/sumcheck/sumcheck.go b/std/sumcheck/sumcheck.go index de3689cbb8..9a278bf7e7 100644 --- a/std/sumcheck/sumcheck.go +++ b/std/sumcheck/sumcheck.go @@ -20,8 +20,8 @@ type LazyClaims interface { // Proof of a multi-sumcheck statement. type Proof struct { - PartialSumPolys []polynomial.Polynomial - FinalEvalProof interface{} + RoundPolyEvaluations []polynomial.Polynomial + FinalEvalProof interface{} } func setupTranscript(api frontend.API, claimsNum int, varsNum int, settings *fiatshamir.Settings) ([]string, error) { @@ -83,18 +83,17 @@ func Verify(api frontend.API, claims LazyClaims, proof Proof, transcriptSettings gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(api, combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { - partialSumPoly := proof.PartialSumPolys[j] //proof.PartialSumPolys(j) - if len(partialSumPoly) != claims.Degree(j) { + roundPolyEvaluation := proof.RoundPolyEvaluations[j] //proof.RoundPolyEvaluations(j) + if len(roundPolyEvaluation) != claims.Degree(j) { return fmt.Errorf("malformed proof") //Malformed proof } - copy(gJ[1:], partialSumPoly) - gJ[0] = api.Sub(gJR, partialSumPoly[0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + copy(gJ[1:], roundPolyEvaluation) + gJ[0] = api.Sub(gJR, roundPolyEvaluation[0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) // gJ is ready //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + if r[j], err = next(transcript, proof.RoundPolyEvaluations[j], &remainingChallengeNames); err != nil { return err }