M cmd/tallyard/main.go => cmd/tallyard/main.go +1 -1
@@ 174,7 174,7 @@ func main() {
for i, joinID := range *el.FinalJoinIDs {
inputs[i] = &el.Joins[joinID].Input
}
- el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeed(), inputs)
+ el.LocalVoter.Poly = math.NewRandomPoly(*ballot, inputs)
el.Save()
}
M election/election.go => election/election.go +0 -21
@@ 7,7 7,6 @@ import (
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
- "tallyard.xyz/math"
)
type Election struct {
@@ 25,8 24,6 @@ type Election struct {
// Save election to disk. Set by the containing ElectionsMap at runtime
Save func() `json:"-"`
-
- hashSeed *string `json:"-"`
}
func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election {
@@ 56,21 53,3 @@ func (el *Election) UnmarshalJSON(b []byte) error {
}
return nil
}
-
-func (el *Election) GetHashSeed() string {
- if el.hashSeed != nil {
- return *el.hashSeed
- }
-
- if el.FinalJoinIDs == nil {
- panic("GatHashSeeds called before election started")
- }
-
- seedParts := make([][]byte, len(*el.FinalJoinIDs))
- for i, joinID := range *el.FinalJoinIDs {
- seedParts[i] = el.Joins[joinID].SeedPart
- }
- hashSeed := math.HashSeedFromSeedParts(seedParts)
- el.hashSeed = &hashSeed
- return hashSeed
-}
M election/marshal_test.go => election/marshal_test.go +4 -24
@@ 39,18 39,6 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
// PubKey / PrivKey
pubKey, privKey, err := box.GenerateKey(rand.Reader)
- // HashSeed(s)
- seedParts := make([][]byte, numVoters)
- for i := range seedParts {
- var seedPart [32]byte
- if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil {
- t.Fatal(err)
- }
- seedParts[i] = seedPart[:]
- }
-
- localHashSeed := seedParts[joinIDIndex]
-
// JoinEvt
joinEvt := &event.Event{
Sender: "@test:example.org",
@@ 64,16 52,13 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
CreateID: "createID",
Input: base64.StdEncoding.EncodeToString(inputBytes[:]),
PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]),
- SeedPart: base64.StdEncoding.EncodeToString(localHashSeed),
},
},
}
- localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, localHashSeed), privKey)
+ localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey), privKey)
localVoter.JoinIDIndex = &joinIDIndex
- hashSeed := math.HashSeedFromSeedParts(seedParts)
-
// Output
localVoter.Output, err = new(fr.Element).SetRandom()
if err != nil {
@@ 112,7 97,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
const numCandidates = 5
{
- r1cs, err := math.EvalsCircuitR1CS(hashSeed, numVoters, numCandidates)
+ r1cs, err := math.EvalsCircuitR1CS(numVoters, numCandidates)
if err != nil {
t.Fatal(err)
}
@@ 128,7 113,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
}
{
- r1cs, err := math.SumCircuitR1CS(hashSeed, numVoters)
+ r1cs, err := math.SumCircuitR1CS(numVoters)
if err != nil {
t.Fatal(err)
}
@@ 151,7 136,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
}
// Poly
- localVoter.Poly = math.NewRandomPoly(ballot, hashSeed, inputs)
+ localVoter.Poly = math.NewRandomPoly(ballot, inputs)
return localVoter
}
@@ 186,11 171,6 @@ func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) {
t.Errorf("pubkeys not equal")
}
- // HashSeed
- if bytes.Compare(lv1.SeedPart, lv2.SeedPart) != 0 {
- t.Errorf("hash seeds not equal")
- }
-
// JoinIDIndex
if *lv2.JoinIDIndex != *lv1.JoinIDIndex {
t.Error("join ID indices not equal")
M election/msg.go => election/msg.go +20 -28
@@ 69,7 69,6 @@ type JoinElectionContent struct {
CreateID id.EventID `json:"create_id"`
Input string `json:"input"`
PubKey string `json:"pub_key"`
- SeedPart string `json:"seed_part"`
}
type StartElectionContent struct {
@@ 305,21 304,6 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success
return
}
- // SeedPart
- if content.SeedPart == "" {
- warnf("the seed part is empty")
- return
- }
- seedPart, err := base64.StdEncoding.DecodeString(content.SeedPart)
- if err != nil {
- warnf("we couldn't decode their seed part: %s", err)
- return
- }
- if len(seedPart) < 32 {
- warnf("their seed part is fewer than 32 bytes")
- return
- }
-
// Input
byts, err := base64.StdEncoding.DecodeString(content.Input)
if err != nil {
@@ 354,7 338,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success
defer el.Save()
defer el.Unlock()
- el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, seedPart)
+ el.Joins[evt.ID] = NewVoter(input, evt, &pubKey)
return true
}
@@ 740,14 724,13 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
var publicCircuit math.EvalsCircuit
// public
- publicCircuit.HashSeed = el.GetHashSeed()
publicCircuit.Inputs = make([]frontend.Variable, len(FinalJoinIDs))
for i, joinID := range FinalJoinIDs {
- publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input)
+ publicCircuit.Inputs[i] = &el.Joins[joinID].Input
}
publicCircuit.OutputHashes = make([]frontend.Variable, len(FinalJoinIDs))
for i, outputHash := range outputHashes {
- publicCircuit.OutputHashes[i].Assign(outputHash)
+ publicCircuit.OutputHashes[i] = outputHash
}
// private
@@ 760,7 743,12 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
errorf("our evals verifying key is nil")
return
}
- err := groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk(), &publicCircuit)
+ witness, err := frontend.NewWitness(&publicCircuit, ecc.BLS12_381, frontend.PublicOnly())
+ if err != nil {
+ errorf("couldn't create witness: %w", err)
+ return
+ }
+ err = groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk(), witness)
if err != nil {
warnf("poly eval proof verification failed: %s", err)
return
@@ 770,7 758,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
// ensure our output hashes to the given hash
{
outputBytes := output.Bytes()
- outputHash, err := mimc.Sum(el.GetHashSeed(), outputBytes[:])
+ outputHash, err := mimc.Sum(outputBytes[:])
if err != nil {
errorf("couldn't hash output: %s", err)
return
@@ 939,16 927,15 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
for i, joinID := range *el.FinalJoinIDs {
evaler := el.Joins[joinID]
if uint(i) == voterJoinIDIndex {
- hashSelects[i].Assign(1)
+ hashSelects[i] = 1
} else {
- hashSelects[i].Assign(0)
+ hashSelects[i] = 0
}
- outputHashes[i].Assign((*evaler.OutputHashes)[voterJoinIDIndex])
+ outputHashes[i] = (*evaler.OutputHashes)[voterJoinIDIndex]
}
- publicCircuit.HashSeed = el.GetHashSeed()
publicCircuit.HashSelects = hashSelects
publicCircuit.OutputHashes = outputHashes
- publicCircuit.Sum.Assign(sum)
+ publicCircuit.Sum = sum
publicCircuit.Outputs = make([]frontend.Variable, n)
if el.LocalVoter.SumVerifyingKey == nil {
@@ 957,7 944,12 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
errorf("our sum verifying key is nil")
return
}
- err := groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk(), &publicCircuit)
+ witness, err := frontend.NewWitness(&publicCircuit, ecc.BLS12_381, frontend.PublicOnly())
+ if err != nil {
+ errorf("couldn't create witness: %w", err)
+ return
+ }
+ err = groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk(), witness)
if err != nil {
warnf("sum proof verification failed: %s", err)
return
M election/voter.go => election/voter.go +20 -23
@@ 8,6 8,7 @@ import (
"fmt"
"io"
+ "github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/frontend"
@@ 23,7 24,6 @@ type Voter struct {
Input fr.Element `json:"input"`
JoinEvt event.Event `json:"join_evt"`
PubKey [32]byte `json:"pub_key"`
- SeedPart []byte `json:"seed_part"`
JoinIDIndex *uint `json:"join_id_index,omitempty"`
Output *fr.Element `json:"output,omitempty"`
@@ 47,12 47,11 @@ type LocalVoter struct {
SumVerifyingKey *VerifyingKeyFile `json:"sum_verifying_key,omitempty"`
}
-func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, seedPart []byte) *Voter {
+func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte) *Voter {
return &Voter{
Input: *new(fr.Element).Set(input),
JoinEvt: *joinEvt,
PubKey: *pubKey,
- SeedPart: seedPart,
}
}
@@ 93,11 92,6 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore)
}
inputBytes := input.Bytes()
- var seedBytes [32]byte
- if _, err := io.ReadFull(rand.Reader, seedBytes[:]); err != nil {
- return fmt.Errorf("couldn't read random bytes: %s", err)
- }
-
commitmentHash, err := CalculateCommitment(el.CreateEvt.Content)
if err != nil {
return fmt.Errorf("couldn't calculate join event's commitment: %s", err)
@@ 110,7 104,6 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore)
CreateID: el.CreateEvt.ID,
Input: base64.StdEncoding.EncodeToString(inputBytes[:]),
PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]),
- SeedPart: base64.StdEncoding.EncodeToString(seedBytes[:]),
})
if err != nil {
return fmt.Errorf("couldn't send join messages: %s", err)
@@ 192,7 185,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
evalVerifyingKey groth16.VerifyingKey
)
{
- r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs), len(el.Candidates))
+ r1cs, err := math.EvalsCircuitR1CS(len(*el.FinalJoinIDs), len(el.Candidates))
if err != nil {
return fmt.Errorf("couldn't compile eval circuit: %s", err)
}
@@ 222,7 215,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
sumVerifyingKey groth16.VerifyingKey
)
{
- r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))
+ r1cs, err := math.SumCircuitR1CS(len(*el.FinalJoinIDs))
if err != nil {
return fmt.Errorf("couldn't compile sum circuit: %s", err)
}
@@ 399,7 392,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
evalsIDs[i] = *voter.EvalsID
}
- r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))
+ r1cs, err := math.SumCircuitR1CS(len(*el.FinalJoinIDs))
proofs := make([]string, numVoters)
{
@@ 407,29 400,33 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
outputHashes := make([]frontend.Variable, numVoters)
outputs := make([]frontend.Variable, numVoters)
- var witness math.SumCircuit
- witness.HashSeed = el.GetHashSeed()
- witness.HashSelects = hashSelects
- witness.OutputHashes = outputHashes
- witness.Sum.Assign(el.LocalVoter.Sum)
- witness.Outputs = outputs
+ var circuit math.SumCircuit
+ circuit.HashSelects = hashSelects
+ circuit.OutputHashes = outputHashes
+ circuit.Sum = el.LocalVoter.Sum
+ circuit.Outputs = outputs
ourJoinIDIndex := *el.LocalVoter.JoinIDIndex
for i, joinID := range *el.FinalJoinIDs {
if uint(i) == ourJoinIDIndex {
- hashSelects[i].Assign(1)
+ hashSelects[i] = 1
} else {
- hashSelects[i].Assign(0)
+ hashSelects[i] = 0
}
evaler := el.Joins[joinID]
outputHashForUs := (*evaler.OutputHashes)[ourJoinIDIndex]
- outputHashes[i].Assign(outputHashForUs)
- outputs[i].Assign(evaler.Output)
+ outputHashes[i] = outputHashForUs
+ outputs[i] = evaler.Output
+ }
+
+ witness, err := frontend.NewWitness(&circuit, ecc.BLS12_381)
+ if err != nil {
+ return fmt.Errorf("couldn't create witness: %w", err)
}
for i, joinID := range *el.FinalJoinIDs {
voter := *el.Joins[joinID]
- proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk(), &witness)
+ proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk(), witness)
if err != nil {
el.RUnlock()
return fmt.Errorf("couldn't prove sum: %s", err)
M go.mod => go.mod +26 -5
@@ 1,17 1,38 @@
module tallyard.xyz
-go 1.13
+go 1.17
require (
- github.com/consensys/gnark v0.5.2
- github.com/consensys/gnark-crypto v0.5.3
- github.com/fxamacker/cbor/v2 v2.3.1 // indirect
+ github.com/consensys/gnark v0.6.4
+ github.com/consensys/gnark-crypto v0.6.1
github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385
github.com/kyoh86/xdg v1.2.0
github.com/rivo/tview v0.0.0-20211202162923-2a6de950f73b
github.com/sirupsen/logrus v1.8.1
golang.org/x/crypto v0.0.0-20220312131142-6068a2e6cfdc
golang.org/x/mod v0.5.1
- golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect
maunium.net/go/mautrix v0.10.11
)
+
+require (
+ github.com/fxamacker/cbor/v2 v2.3.1 // indirect
+ github.com/gdamore/encoding v1.0.0 // indirect
+ github.com/gorilla/mux v1.8.0 // indirect
+ github.com/gorilla/websocket v1.4.2 // indirect
+ github.com/kr/text v0.2.0 // indirect
+ github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
+ github.com/mattn/go-runewidth v0.0.13 // indirect
+ github.com/mmcloughlin/addchain v0.4.0 // indirect
+ github.com/rivo/uniseg v0.2.0 // indirect
+ github.com/tidwall/gjson v1.14.0 // indirect
+ github.com/tidwall/match v1.1.1 // indirect
+ github.com/tidwall/pretty v1.2.0 // indirect
+ github.com/tidwall/sjson v1.2.4 // indirect
+ github.com/x448/float16 v0.8.4 // indirect
+ golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect
+ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect
+ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
+ golang.org/x/text v0.3.7 // indirect
+ gopkg.in/yaml.v2 v2.4.0 // indirect
+ maunium.net/go/maulogger/v2 v2.3.2 // indirect
+)
M go.sum => go.sum +12 -6
@@ 1,8 1,9 @@
-github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI=
-github.com/consensys/gnark v0.5.2 h1:/TTBStGJXkJqFVYFT7YnWmd0PedZlavUb7qOHO2UMEg=
-github.com/consensys/gnark v0.5.2/go.mod h1:gaY1Ij1sp3TnLexb6y9y0KslzqVDvRg+XKldbXXK7ss=
-github.com/consensys/gnark-crypto v0.5.3 h1:4xLFGZR3NWEH2zy+YzvzHicpToQR8FXFbfLNvpGB+rE=
-github.com/consensys/gnark-crypto v0.5.3/go.mod h1:hOdPlWQV1gDLp7faZVeg8Y0iEPFaOUnCc4XeCCk96p0=
+github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI=
+github.com/consensys/gnark v0.6.4 h1:V+iSUwRQvTqug3z+Sb8ve5C9JWTU8bP29sU3HCmlARc=
+github.com/consensys/gnark v0.6.4/go.mod h1:rJwNZk2xhK/V2yYlqBS9Y4FYvZ1347lWejsIr2HRVak=
+github.com/consensys/gnark-crypto v0.6.1 h1:MuWaJyWzSw8wQUOfiZOlRwYjfweIj8dM/u2NN6m0O04=
+github.com/consensys/gnark-crypto v0.6.1/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70=
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ 14,6 15,7 @@ github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo
github.com/gdamore/tcell/v2 v2.4.1-0.20210905002822-f057f0a857a1/go.mod h1:Az6Jt+M5idSED2YPGtwnfJV0kXohgdCBPmHGSYc1r04=
github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385 h1:O5oaOCRcXvNnsPikhB6xGd4a1bbfgcuFQCQgDB4tM7Y=
github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385/go.mod h1:I8YJFI9gzgl4dHi9UlRDZosCW+jYkDA37AXmXvL51w4=
+github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
@@ 21,8 23,9 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kyoh86/xdg v1.2.0 h1:CERuT/ShdTDj+A2UaX3hQ3mOV369+Sj+wyn2nIRIIkI=
github.com/kyoh86/xdg v1.2.0/go.mod h1:/mg8zwu1+qe76oTFUBnyS7rJzk7LLC0VGEzJyJ19DHs=
github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c=
@@ 32,6 35,9 @@ github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i
github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU=
github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.11/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
+github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY=
+github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU=
+github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/tview v0.0.0-20211202162923-2a6de950f73b h1:EMgbQ+bOHWkl0Ptano8M0yrzVZkxans+Vfv7ox/EtO8=
D math/hash.go => math/hash.go +0 -11
@@ 1,11 0,0 @@
-package math
-
-import "golang.org/x/crypto/sha3"
-
-func HashSeedFromSeedParts(seedParts [][]byte) string {
- res := sha3.Sum256(seedParts[0])
- for _, part := range seedParts[1:] {
- res = sha3.Sum256(append(res[:], part...))
- }
- return string(res[:])
-}
M math/lagrange_test.go => math/lagrange_test.go +1 -1
@@ 74,7 74,7 @@ func TestRandomPoly(t *testing.T) {
seedParts[i] = seedPart[:]
}
- poly := NewRandomPoly(ballot, HashSeedFromSeedParts(seedParts), inputs)
+ poly := NewRandomPoly(ballot, inputs)
points := make([]Point, n)
for i := 0; i < n; i++ {
genx:
M math/poly.go => math/poly.go +24 -18
@@ 3,16 3,17 @@ package math
import (
"fmt"
+ "github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc"
"github.com/consensys/gnark/backend/groth16"
+ "github.com/consensys/gnark/backend/witness"
"github.com/consensys/gnark/frontend"
)
type Poly struct {
Ballot [][]byte `json:"ballot"`
Coeffs []*fr.Element `json:"coeffs"`
- HashSeed string `json:"hash_seed"`
Inputs []*fr.Element `json:"inputs"`
cache *polyCache `json:"-"`
@@ 23,7 24,7 @@ type polyCache struct {
constant *fr.Element
outputHashes map[*fr.Element]outputCache
r1cs frontend.CompiledConstraintSystem
- witness EvalsCircuit
+ witness *witness.Witness
}
type outputCache struct {
@@ 33,7 34,7 @@ type outputCache struct {
// NewRandomPoly generates a random polynomial of degree numVoters-1, using
// `ballot' as the constant term
-func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly {
+func NewRandomPoly(ballot [][]byte, inputs []*fr.Element) *Poly {
coeffs := make([]*fr.Element, len(inputs) - 1)
var err error
@@ 44,7 45,7 @@ func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly
}
}
- return &Poly{ballot, coeffs, hashSeed, inputs, nil}
+ return &Poly{ballot, coeffs, inputs, nil}
}
func (p *Poly) setupCache() {
@@ 75,7 76,7 @@ func (p *Poly) setupCache() {
for _, input := range p.Inputs {
output := p.eval(input)
byts := output.Bytes()
- hash, err := mimc.Sum(p.HashSeed, byts[:])
+ hash, err := mimc.Sum(byts[:])
if err != nil {
panic(err)
}
@@ 85,33 86,38 @@ func (p *Poly) setupCache() {
// compile R1CS
{
var err error
- p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeed, len(p.Inputs), len(p.Ballot))
+ p.cache.r1cs, err = EvalsCircuitR1CS(len(p.Inputs), len(p.Ballot))
if err != nil {
panic(err)
}
}
+ circuit := &EvalsCircuit{}
// create witness
{
- witness := &p.cache.witness
// public
- witness.HashSeed = p.HashSeed
- witness.Inputs = make([]frontend.Variable, len(p.Inputs))
- witness.OutputHashes = make([]frontend.Variable, len(p.Inputs))
+ circuit.Inputs = make([]frontend.Variable, len(p.Inputs))
+ circuit.OutputHashes = make([]frontend.Variable, len(p.Inputs))
for i, input := range p.Inputs {
- witness.Inputs[i].Assign(input)
- witness.OutputHashes[i].Assign(p.cache.outputHashes[input].hash)
+ circuit.Inputs[i] = input
+ circuit.OutputHashes[i] = p.cache.outputHashes[input].hash
}
// private
- witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
+ circuit.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
for i, bit := range p.cache.ballotBits {
- witness.BallotBits[i].Assign(bit)
+ circuit.BallotBits[i] = bit
}
- witness.Coeffs = make([]frontend.Variable, len(p.Coeffs))
+ circuit.Coeffs = make([]frontend.Variable, len(p.Coeffs))
for i, coeff := range p.Coeffs {
- witness.Coeffs[i].Assign(coeff)
+ circuit.Coeffs[i] = coeff
}
- witness.Constant.Assign(p.cache.constant)
+ circuit.Constant = p.cache.constant
+ }
+
+ var err error
+ p.cache.witness, err = frontend.NewWitness(circuit, ecc.BLS12_381)
+ if err != nil {
+ panic(fmt.Errorf("couldn't create witness: %w", err))
}
}
@@ 120,7 126,7 @@ func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey) (*
p.setupCache()
}
- proof, err := groth16.Prove(p.cache.r1cs, provingKey, &p.cache.witness)
+ proof, err := groth16.Prove(p.cache.r1cs, provingKey, p.cache.witness)
if err != nil {
panic(err)
}
M math/zk.go => math/zk.go +10 -14
@@ 10,7 10,6 @@ import (
)
type EvalsCircuit struct {
- HashSeed string
Inputs []frontend.Variable `gnark:",public"`
OutputHashes []frontend.Variable `gnark:",public"`
@@ 21,7 20,7 @@ type EvalsCircuit struct {
Constant frontend.Variable `gnark:",secret"`
}
-func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
+func (circuit *EvalsCircuit) Define(api frontend.API) error {
if len(circuit.Inputs) != len(circuit.OutputHashes) {
return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)")
}
@@ 44,7 43,7 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
}
// hash
- mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381, api)
+ mimc, err := mimc.NewMiMC(api)
if err != nil {
return err
}
@@ 55,9 54,9 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
// prove the constant is valid (i.e. a string of bits separated by 7
// zeros)
// TODO: support more than 255 voters (limit from usage of byte)
- zero := api.Constant(0)
- shift := api.Constant(1 << 8)
- slot := api.Constant(1)
+ zero := frontend.Variable(0)
+ shift := frontend.Variable(1 << 8)
+ slot := frontend.Variable(1)
constructedConstant := zero
for i := len(circuit.BallotBits)-1; i >= 0; i-- {
bit := circuit.BallotBits[i]
@@ 70,9 69,8 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
return nil
}
-func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
+func EvalsCircuitR1CS(numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
var circuit EvalsCircuit
- circuit.HashSeed = hashSeed
circuit.Inputs = make([]frontend.Variable, numVoters)
circuit.OutputHashes = make([]frontend.Variable, numVoters)
circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates)
@@ 85,7 83,6 @@ func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (fronte
}
type SumCircuit struct {
- HashSeed string
HashSelects []frontend.Variable `gnark:",public"`
OutputHashes []frontend.Variable `gnark:",public"`
Sum frontend.Variable `gnark:",public"`
@@ 93,7 90,7 @@ type SumCircuit struct {
Outputs []frontend.Variable `gnark:",secret"`
}
-func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error {
+func (circuit *SumCircuit) Define(api frontend.API) error {
if len(circuit.HashSelects) != len(circuit.OutputHashes) {
return errors.New("len(circuit.HashSelects) != len(circuit.OutputHashes)")
}
@@ 102,9 99,9 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error {
}
// prove the secret outputs hash to public hashes
- zero := api.Constant(0)
+ zero := frontend.Variable(0)
- mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381, api)
+ mimc, err := mimc.NewMiMC(api)
if err != nil {
return err
}
@@ 131,9 128,8 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error {
return nil
}
-func SumCircuitR1CS(hashSeed string, numVoters int) (frontend.CompiledConstraintSystem, error) {
+func SumCircuitR1CS(numVoters int) (frontend.CompiledConstraintSystem, error) {
var circuit SumCircuit
- circuit.HashSeed = hashSeed
circuit.HashSelects = make([]frontend.Variable, numVoters)
circuit.OutputHashes = make([]frontend.Variable, numVoters)
circuit.Outputs = make([]frontend.Variable, numVoters)
M version.go => version.go +1 -1
@@ 1,3 1,3 @@
package tallyard
-const Version string = "v0.4.5"
+const Version string = "v0.5.0"