From 8d759a826f4593ba204626203019e34fe669d76e Mon Sep 17 00:00:00 2001 From: David Florness Date: Tue, 22 Mar 2022 18:52:30 -0400 Subject: [PATCH] Upgrade gnark from v0.5.2 to v0.6.4 gnark v0.6.0 contains many breaking changes https://github.com/ConsenSys/gnark/blob/master/CHANGELOG.md#v060---2022-01-03 The tallyard version is also bumped to v0.5.0 since the SeedPart param in the JoinElection message (xyz.tallyard.join) has been dropped. --- cmd/tallyard/main.go | 2 +- election/election.go | 21 ------------------ election/marshal_test.go | 28 ++++------------------- election/msg.go | 48 +++++++++++++++++----------------------- election/voter.go | 43 +++++++++++++++++------------------ go.mod | 31 +++++++++++++++++++++----- go.sum | 18 ++++++++++----- math/hash.go | 11 --------- math/lagrange_test.go | 2 +- math/poly.go | 42 ++++++++++++++++++++--------------- math/zk.go | 24 +++++++++----------- version.go | 2 +- 12 files changed, 119 insertions(+), 153 deletions(-) delete mode 100644 math/hash.go diff --git a/cmd/tallyard/main.go b/cmd/tallyard/main.go index d33bc94..2f59392 100644 --- a/cmd/tallyard/main.go +++ b/cmd/tallyard/main.go @@ -174,7 +174,7 @@ func main() { for i, joinID := range *el.FinalJoinIDs { inputs[i] = &el.Joins[joinID].Input } - el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeed(), inputs) + el.LocalVoter.Poly = math.NewRandomPoly(*ballot, inputs) el.Save() } diff --git a/election/election.go b/election/election.go index 2acaa3d..f6b8993 100644 --- a/election/election.go +++ b/election/election.go @@ -7,7 +7,6 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "tallyard.xyz/math" ) type Election struct { @@ -25,8 +24,6 @@ type Election struct { // Save election to disk. Set by the containing ElectionsMap at runtime Save func() `json:"-"` - - hashSeed *string `json:"-"` } func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election { @@ -56,21 +53,3 @@ func (el *Election) UnmarshalJSON(b []byte) error { } return nil } - -func (el *Election) GetHashSeed() string { - if el.hashSeed != nil { - return *el.hashSeed - } - - if el.FinalJoinIDs == nil { - panic("GatHashSeeds called before election started") - } - - seedParts := make([][]byte, len(*el.FinalJoinIDs)) - for i, joinID := range *el.FinalJoinIDs { - seedParts[i] = el.Joins[joinID].SeedPart - } - hashSeed := math.HashSeedFromSeedParts(seedParts) - el.hashSeed = &hashSeed - return hashSeed -} diff --git a/election/marshal_test.go b/election/marshal_test.go index 9b18c46..9f73c3a 100644 --- a/election/marshal_test.go +++ b/election/marshal_test.go @@ -39,18 +39,6 @@ func randomLocalVoter(t *testing.T) *LocalVoter { // PubKey / PrivKey pubKey, privKey, err := box.GenerateKey(rand.Reader) - // HashSeed(s) - seedParts := make([][]byte, numVoters) - for i := range seedParts { - var seedPart [32]byte - if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil { - t.Fatal(err) - } - seedParts[i] = seedPart[:] - } - - localHashSeed := seedParts[joinIDIndex] - // JoinEvt joinEvt := &event.Event{ Sender: "@test:example.org", @@ -64,16 +52,13 @@ func randomLocalVoter(t *testing.T) *LocalVoter { CreateID: "createID", Input: base64.StdEncoding.EncodeToString(inputBytes[:]), PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]), - SeedPart: base64.StdEncoding.EncodeToString(localHashSeed), }, }, } - localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, localHashSeed), privKey) + localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey), privKey) localVoter.JoinIDIndex = &joinIDIndex - hashSeed := math.HashSeedFromSeedParts(seedParts) - // Output localVoter.Output, err = new(fr.Element).SetRandom() if err != nil { @@ -112,7 +97,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter { const numCandidates = 5 { - r1cs, err := math.EvalsCircuitR1CS(hashSeed, numVoters, numCandidates) + r1cs, err := math.EvalsCircuitR1CS(numVoters, numCandidates) if err != nil { t.Fatal(err) } @@ -128,7 +113,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter { } { - r1cs, err := math.SumCircuitR1CS(hashSeed, numVoters) + r1cs, err := math.SumCircuitR1CS(numVoters) if err != nil { t.Fatal(err) } @@ -151,7 +136,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter { } // Poly - localVoter.Poly = math.NewRandomPoly(ballot, hashSeed, inputs) + localVoter.Poly = math.NewRandomPoly(ballot, inputs) return localVoter } @@ -186,11 +171,6 @@ func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) { t.Errorf("pubkeys not equal") } - // HashSeed - if bytes.Compare(lv1.SeedPart, lv2.SeedPart) != 0 { - t.Errorf("hash seeds not equal") - } - // JoinIDIndex if *lv2.JoinIDIndex != *lv1.JoinIDIndex { t.Error("join ID indices not equal") diff --git a/election/msg.go b/election/msg.go index 51693aa..32fa436 100644 --- a/election/msg.go +++ b/election/msg.go @@ -69,7 +69,6 @@ type JoinElectionContent struct { CreateID id.EventID `json:"create_id"` Input string `json:"input"` PubKey string `json:"pub_key"` - SeedPart string `json:"seed_part"` } type StartElectionContent struct { @@ -305,21 +304,6 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success return } - // SeedPart - if content.SeedPart == "" { - warnf("the seed part is empty") - return - } - seedPart, err := base64.StdEncoding.DecodeString(content.SeedPart) - if err != nil { - warnf("we couldn't decode their seed part: %s", err) - return - } - if len(seedPart) < 32 { - warnf("their seed part is fewer than 32 bytes") - return - } - // Input byts, err := base64.StdEncoding.DecodeString(content.Input) if err != nil { @@ -354,7 +338,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success defer el.Save() defer el.Unlock() - el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, seedPart) + el.Joins[evt.ID] = NewVoter(input, evt, &pubKey) return true } @@ -740,14 +724,13 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { var publicCircuit math.EvalsCircuit // public - publicCircuit.HashSeed = el.GetHashSeed() publicCircuit.Inputs = make([]frontend.Variable, len(FinalJoinIDs)) for i, joinID := range FinalJoinIDs { - publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input) + publicCircuit.Inputs[i] = &el.Joins[joinID].Input } publicCircuit.OutputHashes = make([]frontend.Variable, len(FinalJoinIDs)) for i, outputHash := range outputHashes { - publicCircuit.OutputHashes[i].Assign(outputHash) + publicCircuit.OutputHashes[i] = outputHash } // private @@ -760,7 +743,12 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { errorf("our evals verifying key is nil") return } - err := groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk(), &publicCircuit) + witness, err := frontend.NewWitness(&publicCircuit, ecc.BLS12_381, frontend.PublicOnly()) + if err != nil { + errorf("couldn't create witness: %w", err) + return + } + err = groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk(), witness) if err != nil { warnf("poly eval proof verification failed: %s", err) return @@ -770,7 +758,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { // ensure our output hashes to the given hash { outputBytes := output.Bytes() - outputHash, err := mimc.Sum(el.GetHashSeed(), outputBytes[:]) + outputHash, err := mimc.Sum(outputBytes[:]) if err != nil { errorf("couldn't hash output: %s", err) return @@ -939,16 +927,15 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) { for i, joinID := range *el.FinalJoinIDs { evaler := el.Joins[joinID] if uint(i) == voterJoinIDIndex { - hashSelects[i].Assign(1) + hashSelects[i] = 1 } else { - hashSelects[i].Assign(0) + hashSelects[i] = 0 } - outputHashes[i].Assign((*evaler.OutputHashes)[voterJoinIDIndex]) + outputHashes[i] = (*evaler.OutputHashes)[voterJoinIDIndex] } - publicCircuit.HashSeed = el.GetHashSeed() publicCircuit.HashSelects = hashSelects publicCircuit.OutputHashes = outputHashes - publicCircuit.Sum.Assign(sum) + publicCircuit.Sum = sum publicCircuit.Outputs = make([]frontend.Variable, n) if el.LocalVoter.SumVerifyingKey == nil { @@ -957,7 +944,12 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) { errorf("our sum verifying key is nil") return } - err := groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk(), &publicCircuit) + witness, err := frontend.NewWitness(&publicCircuit, ecc.BLS12_381, frontend.PublicOnly()) + if err != nil { + errorf("couldn't create witness: %w", err) + return + } + err = groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk(), witness) if err != nil { warnf("sum proof verification failed: %s", err) return diff --git a/election/voter.go b/election/voter.go index 510c023..dd3ed94 100644 --- a/election/voter.go +++ b/election/voter.go @@ -8,6 +8,7 @@ import ( "fmt" "io" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" @@ -23,7 +24,6 @@ type Voter struct { Input fr.Element `json:"input"` JoinEvt event.Event `json:"join_evt"` PubKey [32]byte `json:"pub_key"` - SeedPart []byte `json:"seed_part"` JoinIDIndex *uint `json:"join_id_index,omitempty"` Output *fr.Element `json:"output,omitempty"` @@ -47,12 +47,11 @@ type LocalVoter struct { SumVerifyingKey *VerifyingKeyFile `json:"sum_verifying_key,omitempty"` } -func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, seedPart []byte) *Voter { +func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte) *Voter { return &Voter{ Input: *new(fr.Element).Set(input), JoinEvt: *joinEvt, PubKey: *pubKey, - SeedPart: seedPart, } } @@ -93,11 +92,6 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore) } inputBytes := input.Bytes() - var seedBytes [32]byte - if _, err := io.ReadFull(rand.Reader, seedBytes[:]); err != nil { - return fmt.Errorf("couldn't read random bytes: %s", err) - } - commitmentHash, err := CalculateCommitment(el.CreateEvt.Content) if err != nil { return fmt.Errorf("couldn't calculate join event's commitment: %s", err) @@ -110,7 +104,6 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore) CreateID: el.CreateEvt.ID, Input: base64.StdEncoding.EncodeToString(inputBytes[:]), PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]), - SeedPart: base64.StdEncoding.EncodeToString(seedBytes[:]), }) if err != nil { return fmt.Errorf("couldn't send join messages: %s", err) @@ -192,7 +185,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto evalVerifyingKey groth16.VerifyingKey ) { - r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs), len(el.Candidates)) + r1cs, err := math.EvalsCircuitR1CS(len(*el.FinalJoinIDs), len(el.Candidates)) if err != nil { return fmt.Errorf("couldn't compile eval circuit: %s", err) } @@ -222,7 +215,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto sumVerifyingKey groth16.VerifyingKey ) { - r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs)) + r1cs, err := math.SumCircuitR1CS(len(*el.FinalJoinIDs)) if err != nil { return fmt.Errorf("couldn't compile sum circuit: %s", err) } @@ -399,7 +392,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro evalsIDs[i] = *voter.EvalsID } - r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs)) + r1cs, err := math.SumCircuitR1CS(len(*el.FinalJoinIDs)) proofs := make([]string, numVoters) { @@ -407,29 +400,33 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro outputHashes := make([]frontend.Variable, numVoters) outputs := make([]frontend.Variable, numVoters) - var witness math.SumCircuit - witness.HashSeed = el.GetHashSeed() - witness.HashSelects = hashSelects - witness.OutputHashes = outputHashes - witness.Sum.Assign(el.LocalVoter.Sum) - witness.Outputs = outputs + var circuit math.SumCircuit + circuit.HashSelects = hashSelects + circuit.OutputHashes = outputHashes + circuit.Sum = el.LocalVoter.Sum + circuit.Outputs = outputs ourJoinIDIndex := *el.LocalVoter.JoinIDIndex for i, joinID := range *el.FinalJoinIDs { if uint(i) == ourJoinIDIndex { - hashSelects[i].Assign(1) + hashSelects[i] = 1 } else { - hashSelects[i].Assign(0) + hashSelects[i] = 0 } evaler := el.Joins[joinID] outputHashForUs := (*evaler.OutputHashes)[ourJoinIDIndex] - outputHashes[i].Assign(outputHashForUs) - outputs[i].Assign(evaler.Output) + outputHashes[i] = outputHashForUs + outputs[i] = evaler.Output + } + + witness, err := frontend.NewWitness(&circuit, ecc.BLS12_381) + if err != nil { + return fmt.Errorf("couldn't create witness: %w", err) } for i, joinID := range *el.FinalJoinIDs { voter := *el.Joins[joinID] - proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk(), &witness) + proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk(), witness) if err != nil { el.RUnlock() return fmt.Errorf("couldn't prove sum: %s", err) diff --git a/go.mod b/go.mod index 25ef5b3..b744346 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,38 @@ module tallyard.xyz -go 1.13 +go 1.17 require ( - github.com/consensys/gnark v0.5.2 - github.com/consensys/gnark-crypto v0.5.3 - github.com/fxamacker/cbor/v2 v2.3.1 // indirect + github.com/consensys/gnark v0.6.4 + github.com/consensys/gnark-crypto v0.6.1 github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385 github.com/kyoh86/xdg v1.2.0 github.com/rivo/tview v0.0.0-20211202162923-2a6de950f73b github.com/sirupsen/logrus v1.8.1 golang.org/x/crypto v0.0.0-20220312131142-6068a2e6cfdc golang.org/x/mod v0.5.1 - golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect maunium.net/go/mautrix v0.10.11 ) + +require ( + github.com/fxamacker/cbor/v2 v2.3.1 // indirect + github.com/gdamore/encoding v1.0.0 // indirect + github.com/gorilla/mux v1.8.0 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/mattn/go-runewidth v0.0.13 // indirect + github.com/mmcloughlin/addchain v0.4.0 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/tidwall/gjson v1.14.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.4 // indirect + github.com/x448/float16 v0.8.4 // indirect + golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect + golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect + golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect + golang.org/x/text v0.3.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + maunium.net/go/maulogger/v2 v2.3.2 // indirect +) diff --git a/go.sum b/go.sum index 138926e..c66c45c 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,9 @@ -github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/gnark v0.5.2 h1:/TTBStGJXkJqFVYFT7YnWmd0PedZlavUb7qOHO2UMEg= -github.com/consensys/gnark v0.5.2/go.mod h1:gaY1Ij1sp3TnLexb6y9y0KslzqVDvRg+XKldbXXK7ss= -github.com/consensys/gnark-crypto v0.5.3 h1:4xLFGZR3NWEH2zy+YzvzHicpToQR8FXFbfLNvpGB+rE= -github.com/consensys/gnark-crypto v0.5.3/go.mod h1:hOdPlWQV1gDLp7faZVeg8Y0iEPFaOUnCc4XeCCk96p0= +github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/gnark v0.6.4 h1:V+iSUwRQvTqug3z+Sb8ve5C9JWTU8bP29sU3HCmlARc= +github.com/consensys/gnark v0.6.4/go.mod h1:rJwNZk2xhK/V2yYlqBS9Y4FYvZ1347lWejsIr2HRVak= +github.com/consensys/gnark-crypto v0.6.1 h1:MuWaJyWzSw8wQUOfiZOlRwYjfweIj8dM/u2NN6m0O04= +github.com/consensys/gnark-crypto v0.6.1/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -14,6 +15,7 @@ github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo github.com/gdamore/tcell/v2 v2.4.1-0.20210905002822-f057f0a857a1/go.mod h1:Az6Jt+M5idSED2YPGtwnfJV0kXohgdCBPmHGSYc1r04= github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385 h1:O5oaOCRcXvNnsPikhB6xGd4a1bbfgcuFQCQgDB4tM7Y= github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385/go.mod h1:I8YJFI9gzgl4dHi9UlRDZosCW+jYkDA37AXmXvL51w4= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= @@ -21,8 +23,9 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kyoh86/xdg v1.2.0 h1:CERuT/ShdTDj+A2UaX3hQ3mOV369+Sj+wyn2nIRIIkI= github.com/kyoh86/xdg v1.2.0/go.mod h1:/mg8zwu1+qe76oTFUBnyS7rJzk7LLC0VGEzJyJ19DHs= github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= @@ -32,6 +35,9 @@ github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.11/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= +github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= +github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/tview v0.0.0-20211202162923-2a6de950f73b h1:EMgbQ+bOHWkl0Ptano8M0yrzVZkxans+Vfv7ox/EtO8= diff --git a/math/hash.go b/math/hash.go deleted file mode 100644 index 0209d9b..0000000 --- a/math/hash.go +++ /dev/null @@ -1,11 +0,0 @@ -package math - -import "golang.org/x/crypto/sha3" - -func HashSeedFromSeedParts(seedParts [][]byte) string { - res := sha3.Sum256(seedParts[0]) - for _, part := range seedParts[1:] { - res = sha3.Sum256(append(res[:], part...)) - } - return string(res[:]) -} diff --git a/math/lagrange_test.go b/math/lagrange_test.go index a96b391..d221dd9 100644 --- a/math/lagrange_test.go +++ b/math/lagrange_test.go @@ -74,7 +74,7 @@ func TestRandomPoly(t *testing.T) { seedParts[i] = seedPart[:] } - poly := NewRandomPoly(ballot, HashSeedFromSeedParts(seedParts), inputs) + poly := NewRandomPoly(ballot, inputs) points := make([]Point, n) for i := 0; i < n; i++ { genx: diff --git a/math/poly.go b/math/poly.go index 0e0ae46..b6f3e1b 100644 --- a/math/poly.go +++ b/math/poly.go @@ -3,16 +3,17 @@ package math import ( "fmt" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/frontend" ) type Poly struct { Ballot [][]byte `json:"ballot"` Coeffs []*fr.Element `json:"coeffs"` - HashSeed string `json:"hash_seed"` Inputs []*fr.Element `json:"inputs"` cache *polyCache `json:"-"` @@ -23,7 +24,7 @@ type polyCache struct { constant *fr.Element outputHashes map[*fr.Element]outputCache r1cs frontend.CompiledConstraintSystem - witness EvalsCircuit + witness *witness.Witness } type outputCache struct { @@ -33,7 +34,7 @@ type outputCache struct { // NewRandomPoly generates a random polynomial of degree numVoters-1, using // `ballot' as the constant term -func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly { +func NewRandomPoly(ballot [][]byte, inputs []*fr.Element) *Poly { coeffs := make([]*fr.Element, len(inputs) - 1) var err error @@ -44,7 +45,7 @@ func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly } } - return &Poly{ballot, coeffs, hashSeed, inputs, nil} + return &Poly{ballot, coeffs, inputs, nil} } func (p *Poly) setupCache() { @@ -75,7 +76,7 @@ func (p *Poly) setupCache() { for _, input := range p.Inputs { output := p.eval(input) byts := output.Bytes() - hash, err := mimc.Sum(p.HashSeed, byts[:]) + hash, err := mimc.Sum(byts[:]) if err != nil { panic(err) } @@ -85,33 +86,38 @@ func (p *Poly) setupCache() { // compile R1CS { var err error - p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeed, len(p.Inputs), len(p.Ballot)) + p.cache.r1cs, err = EvalsCircuitR1CS(len(p.Inputs), len(p.Ballot)) if err != nil { panic(err) } } + circuit := &EvalsCircuit{} // create witness { - witness := &p.cache.witness // public - witness.HashSeed = p.HashSeed - witness.Inputs = make([]frontend.Variable, len(p.Inputs)) - witness.OutputHashes = make([]frontend.Variable, len(p.Inputs)) + circuit.Inputs = make([]frontend.Variable, len(p.Inputs)) + circuit.OutputHashes = make([]frontend.Variable, len(p.Inputs)) for i, input := range p.Inputs { - witness.Inputs[i].Assign(input) - witness.OutputHashes[i].Assign(p.cache.outputHashes[input].hash) + circuit.Inputs[i] = input + circuit.OutputHashes[i] = p.cache.outputHashes[input].hash } // private - witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0])) + circuit.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0])) for i, bit := range p.cache.ballotBits { - witness.BallotBits[i].Assign(bit) + circuit.BallotBits[i] = bit } - witness.Coeffs = make([]frontend.Variable, len(p.Coeffs)) + circuit.Coeffs = make([]frontend.Variable, len(p.Coeffs)) for i, coeff := range p.Coeffs { - witness.Coeffs[i].Assign(coeff) + circuit.Coeffs[i] = coeff } - witness.Constant.Assign(p.cache.constant) + circuit.Constant = p.cache.constant + } + + var err error + p.cache.witness, err = frontend.NewWitness(circuit, ecc.BLS12_381) + if err != nil { + panic(fmt.Errorf("couldn't create witness: %w", err)) } } @@ -120,7 +126,7 @@ func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey) (* p.setupCache() } - proof, err := groth16.Prove(p.cache.r1cs, provingKey, &p.cache.witness) + proof, err := groth16.Prove(p.cache.r1cs, provingKey, p.cache.witness) if err != nil { panic(err) } diff --git a/math/zk.go b/math/zk.go index 0c27175..28ad359 100644 --- a/math/zk.go +++ b/math/zk.go @@ -10,7 +10,6 @@ import ( ) type EvalsCircuit struct { - HashSeed string Inputs []frontend.Variable `gnark:",public"` OutputHashes []frontend.Variable `gnark:",public"` @@ -21,7 +20,7 @@ type EvalsCircuit struct { Constant frontend.Variable `gnark:",secret"` } -func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error { +func (circuit *EvalsCircuit) Define(api frontend.API) error { if len(circuit.Inputs) != len(circuit.OutputHashes) { return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)") } @@ -44,7 +43,7 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error { } // hash - mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381, api) + mimc, err := mimc.NewMiMC(api) if err != nil { return err } @@ -55,9 +54,9 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error { // prove the constant is valid (i.e. a string of bits separated by 7 // zeros) // TODO: support more than 255 voters (limit from usage of byte) - zero := api.Constant(0) - shift := api.Constant(1 << 8) - slot := api.Constant(1) + zero := frontend.Variable(0) + shift := frontend.Variable(1 << 8) + slot := frontend.Variable(1) constructedConstant := zero for i := len(circuit.BallotBits)-1; i >= 0; i-- { bit := circuit.BallotBits[i] @@ -70,9 +69,8 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error { return nil } -func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) { +func EvalsCircuitR1CS(numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) { var circuit EvalsCircuit - circuit.HashSeed = hashSeed circuit.Inputs = make([]frontend.Variable, numVoters) circuit.OutputHashes = make([]frontend.Variable, numVoters) circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates) @@ -85,7 +83,6 @@ func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (fronte } type SumCircuit struct { - HashSeed string HashSelects []frontend.Variable `gnark:",public"` OutputHashes []frontend.Variable `gnark:",public"` Sum frontend.Variable `gnark:",public"` @@ -93,7 +90,7 @@ type SumCircuit struct { Outputs []frontend.Variable `gnark:",secret"` } -func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error { +func (circuit *SumCircuit) Define(api frontend.API) error { if len(circuit.HashSelects) != len(circuit.OutputHashes) { return errors.New("len(circuit.HashSelects) != len(circuit.OutputHashes)") } @@ -102,9 +99,9 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error { } // prove the secret outputs hash to public hashes - zero := api.Constant(0) + zero := frontend.Variable(0) - mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381, api) + mimc, err := mimc.NewMiMC(api) if err != nil { return err } @@ -131,9 +128,8 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error { return nil } -func SumCircuitR1CS(hashSeed string, numVoters int) (frontend.CompiledConstraintSystem, error) { +func SumCircuitR1CS(numVoters int) (frontend.CompiledConstraintSystem, error) { var circuit SumCircuit - circuit.HashSeed = hashSeed circuit.HashSelects = make([]frontend.Variable, numVoters) circuit.OutputHashes = make([]frontend.Variable, numVoters) circuit.Outputs = make([]frontend.Variable, numVoters) diff --git a/version.go b/version.go index ab1e34a..a612ee5 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package tallyard -const Version string = "v0.4.5" +const Version string = "v0.5.0" -- 2.38.4