From d6670588a8acc712ba0aa2637464fc40af73dec2 Mon Sep 17 00:00:00 2001 From: David Florness Date: Sat, 19 Jun 2021 12:45:15 -0400 Subject: [PATCH] zk: derive one hash seed by combining each "part" sent by voters This slightly simplifies the circuits and protects against voters picking advantageous seeds [0]. [0]: Hopefully. This seems evident to me, but I'm not a cryptographer :) --- cmd/tallyard/main.go | 2 +- election/election.go | 18 ++++++++++-------- election/marshal_test.go | 27 +++++++++++++++------------ election/msg.go | 24 ++++++++++++------------ election/voter.go | 16 ++++++++-------- math/hash.go | 11 +++++++++++ math/lagrange_test.go | 13 ++++++------- math/poly.go | 20 ++++++++++---------- math/zk.go | 32 +++++++++++++------------------- 9 files changed, 86 insertions(+), 77 deletions(-) create mode 100644 math/hash.go diff --git a/cmd/tallyard/main.go b/cmd/tallyard/main.go index b79964c..7763a24 100644 --- a/cmd/tallyard/main.go +++ b/cmd/tallyard/main.go @@ -128,7 +128,7 @@ func main() { for i, joinID := range *el.FinalJoinIDs { inputs[i] = &el.Joins[joinID].Input } - el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeeds(), inputs) + el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeed(), inputs) el.Save() } diff --git a/election/election.go b/election/election.go index ce03213..2acaa3d 100644 --- a/election/election.go +++ b/election/election.go @@ -7,6 +7,7 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "tallyard.xyz/math" ) type Election struct { @@ -25,7 +26,7 @@ type Election struct { // Save election to disk. Set by the containing ElectionsMap at runtime Save func() `json:"-"` - hashSeeds *[]string `json:"-"` + hashSeed *string `json:"-"` } func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election { @@ -56,19 +57,20 @@ func (el *Election) UnmarshalJSON(b []byte) error { return nil } -func (el *Election) GetHashSeeds() []string { - if el.hashSeeds != nil { - return *el.hashSeeds +func (el *Election) GetHashSeed() string { + if el.hashSeed != nil { + return *el.hashSeed } if el.FinalJoinIDs == nil { panic("GatHashSeeds called before election started") } - hashSeeds := make([]string, len(*el.FinalJoinIDs)) + seedParts := make([][]byte, len(*el.FinalJoinIDs)) for i, joinID := range *el.FinalJoinIDs { - hashSeeds[i] = el.Joins[joinID].HashSeed + seedParts[i] = el.Joins[joinID].SeedPart } - el.hashSeeds = &hashSeeds - return hashSeeds + hashSeed := math.HashSeedFromSeedParts(seedParts) + el.hashSeed = &hashSeed + return hashSeed } diff --git a/election/marshal_test.go b/election/marshal_test.go index 4bcbca2..5758850 100644 --- a/election/marshal_test.go +++ b/election/marshal_test.go @@ -40,15 +40,16 @@ func randomLocalVoter(t *testing.T) *LocalVoter { pubKey, privKey, err := box.GenerateKey(rand.Reader) // HashSeed(s) - hashSeeds := make([]string, numVoters) - for i := range hashSeeds { - var hashSeed [32]byte - if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil { + seedParts := make([][]byte, numVoters) + for i := range seedParts { + var seedPart [32]byte + if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil { t.Fatal(err) } - hashSeeds[i] = base64.StdEncoding.EncodeToString(hashSeed[:]) + seedParts[i] = seedPart[:] } - hashSeed := hashSeeds[joinIDIndex] + + localHashSeed := seedParts[joinIDIndex] // JoinEvt joinEvt := &event.Event{ @@ -63,14 +64,16 @@ func randomLocalVoter(t *testing.T) *LocalVoter { CreateID: "createID", Input: base64.StdEncoding.EncodeToString(inputBytes[:]), PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]), - HashSeed: hashSeed, + SeedPart: base64.StdEncoding.EncodeToString(localHashSeed), }, }, } - localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, hashSeed), privKey) + localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, localHashSeed), privKey) localVoter.JoinIDIndex = &joinIDIndex + hashSeed := math.HashSeedFromSeedParts(seedParts) + // Output localVoter.Output, err = new(fr.Element).SetRandom() if err != nil { @@ -109,7 +112,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter { const numCandidates = 5 { - r1cs, err := math.EvalsCircuitR1CS(hashSeeds, numVoters, numCandidates) + r1cs, err := math.EvalsCircuitR1CS(hashSeed, numVoters, numCandidates) if err != nil { t.Fatal(err) } @@ -125,7 +128,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter { } { - r1cs, err := math.SumCircuitR1CS(hashSeeds, numVoters) + r1cs, err := math.SumCircuitR1CS(hashSeed, numVoters) if err != nil { t.Fatal(err) } @@ -148,7 +151,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter { } // Poly - localVoter.Poly = math.NewRandomPoly(ballot, hashSeeds, inputs) + localVoter.Poly = math.NewRandomPoly(ballot, hashSeed, inputs) return localVoter } @@ -184,7 +187,7 @@ func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) { } // HashSeed - if lv1.HashSeed != lv2.HashSeed { + if bytes.Compare(lv1.SeedPart, lv2.SeedPart) != 0 { t.Errorf("hash seeds not equal") } diff --git a/election/msg.go b/election/msg.go index a41593f..3951720 100644 --- a/election/msg.go +++ b/election/msg.go @@ -69,7 +69,7 @@ type JoinElectionContent struct { CreateID id.EventID `json:"create_id"` Input string `json:"input"` PubKey string `json:"pub_key"` - HashSeed string `json:"hash_seed"` + SeedPart string `json:"seed_part"` } type StartElectionContent struct { @@ -281,18 +281,18 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success return } - // HashSeed - if content.HashSeed == "" { - warnf("the hash seed is empty") + // SeedPart + if content.SeedPart == "" { + warnf("the seed part is empty") return } - hashSeedBytes, err := base64.StdEncoding.DecodeString(content.HashSeed) + seedPart, err := base64.StdEncoding.DecodeString(content.SeedPart) if err != nil { - warnf("we couldn't decode their hash seed: %s", err) + warnf("we couldn't decode their seed part: %s", err) return } - if len(hashSeedBytes) < 32 { - warnf("their hash seed is fewer than 32 bytes") + if len(seedPart) < 32 { + warnf("their seed part is fewer than 32 bytes") return } @@ -330,7 +330,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success defer el.Save() defer el.Unlock() - el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, content.HashSeed) + el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, seedPart) return true } @@ -665,7 +665,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { var publicCircuit math.EvalsCircuit // public - publicCircuit.HashSeeds = el.GetHashSeeds() + publicCircuit.HashSeed = el.GetHashSeed() publicCircuit.Inputs = make([]frontend.Variable, len(*el.FinalJoinIDs)) for i, joinID := range *el.FinalJoinIDs { publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input) @@ -695,7 +695,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { // ensure our output hashes to the given hash { outputBytes := output.Bytes() - outputHash, err := mimc.Sum(el.LocalVoter.HashSeed, outputBytes[:]) + outputHash, err := mimc.Sum(el.GetHashSeed(), outputBytes[:]) if err != nil { errorf("couldn't hash output: %s", err) return @@ -847,7 +847,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) { } outputHashes[i].Assign((*evaler.OutputHashes)[voterJoinIDIndex]) } - publicCircuit.HashSeeds = el.GetHashSeeds() + publicCircuit.HashSeed = el.GetHashSeed() publicCircuit.HashSelects = hashSelects publicCircuit.OutputHashes = outputHashes publicCircuit.Sum.Assign(sum) diff --git a/election/voter.go b/election/voter.go index a5556c7..32549c8 100644 --- a/election/voter.go +++ b/election/voter.go @@ -23,7 +23,7 @@ type Voter struct { Input fr.Element `json:"input"` JoinEvt event.Event `json:"join_evt"` PubKey [32]byte `json:"pub_key"` - HashSeed string `json:"hash_seed"` + SeedPart []byte `json:"seed_part"` JoinIDIndex *uint `json:"join_id_index,omitempty"` Output *fr.Element `json:"output,omitempty"` @@ -47,12 +47,12 @@ type LocalVoter struct { SumVerifyingKey *MarshallableVerifyingKey `json:"sum_verifying_key,omitempty"` } -func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed string) *Voter { +func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, seedPart []byte) *Voter { return &Voter{ Input: *new(fr.Element).Set(input), JoinEvt: *joinEvt, PubKey: *pubKey, - HashSeed: hashSeed, + SeedPart: seedPart, } } @@ -104,7 +104,7 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore) CreateID: el.CreateEvt.ID, Input: base64.StdEncoding.EncodeToString(inputBytes[:]), PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]), - HashSeed: base64.StdEncoding.EncodeToString(seedBytes[:]), + SeedPart: base64.StdEncoding.EncodeToString(seedBytes[:]), }) if err != nil { return fmt.Errorf("couldn't send join messages: %s", err) @@ -171,7 +171,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto evalVerifyingKey groth16.VerifyingKey ) { - r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs), len(el.Candidates)) + r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs), len(el.Candidates)) if err != nil { return fmt.Errorf("couldn't compile eval circuit: %s", err) } @@ -200,7 +200,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto sumVerifyingKey groth16.VerifyingKey ) { - r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs)) + r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs)) if err != nil { return fmt.Errorf("couldn't compile sum circuit: %s", err) } @@ -346,7 +346,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro evalsIDs[i] = *voter.EvalsID } - r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs)) + r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs)) proofs := make([]string, numVoters) { @@ -355,7 +355,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro outputs := make([]frontend.Variable, numVoters) var witness math.SumCircuit - witness.HashSeeds = el.GetHashSeeds() + witness.HashSeed = el.GetHashSeed() witness.HashSelects = hashSelects witness.OutputHashes = outputHashes witness.Sum.Assign(el.LocalVoter.Sum) diff --git a/math/hash.go b/math/hash.go new file mode 100644 index 0000000..0209d9b --- /dev/null +++ b/math/hash.go @@ -0,0 +1,11 @@ +package math + +import "golang.org/x/crypto/sha3" + +func HashSeedFromSeedParts(seedParts [][]byte) string { + res := sha3.Sum256(seedParts[0]) + for _, part := range seedParts[1:] { + res = sha3.Sum256(append(res[:], part...)) + } + return string(res[:]) +} diff --git a/math/lagrange_test.go b/math/lagrange_test.go index 5c94f2c..a96b391 100644 --- a/math/lagrange_test.go +++ b/math/lagrange_test.go @@ -2,7 +2,6 @@ package math import ( "crypto/rand" - "encoding/base64" "io" "testing" @@ -66,16 +65,16 @@ func TestRandomPoly(t *testing.T) { } // hash seeds don't matter - hashSeeds := make([]string, n) - for i := range hashSeeds { - var hashSeed [32]byte - if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil { + seedParts := make([][]byte, n) + for i := range seedParts { + var seedPart [32]byte + if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil { t.Fatal(err) } - hashSeeds[i] = base64.RawStdEncoding.EncodeToString(hashSeed[:]) + seedParts[i] = seedPart[:] } - poly := NewRandomPoly(ballot, hashSeeds, inputs) + poly := NewRandomPoly(ballot, HashSeedFromSeedParts(seedParts), inputs) points := make([]Point, n) for i := 0; i < n; i++ { genx: diff --git a/math/poly.go b/math/poly.go index d4063c5..0e0ae46 100644 --- a/math/poly.go +++ b/math/poly.go @@ -10,10 +10,10 @@ import ( ) type Poly struct { - Ballot [][]byte `json:"ballot"` - Coeffs []*fr.Element `json:"coeffs"` - HashSeeds []string `json:"hash_seeds"` - Inputs []*fr.Element `json:"inputs"` + Ballot [][]byte `json:"ballot"` + Coeffs []*fr.Element `json:"coeffs"` + HashSeed string `json:"hash_seed"` + Inputs []*fr.Element `json:"inputs"` cache *polyCache `json:"-"` } @@ -33,7 +33,7 @@ type outputCache struct { // NewRandomPoly generates a random polynomial of degree numVoters-1, using // `ballot' as the constant term -func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *Poly { +func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly { coeffs := make([]*fr.Element, len(inputs) - 1) var err error @@ -44,7 +44,7 @@ func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *P } } - return &Poly{ballot, coeffs, hashSeeds, inputs, nil} + return &Poly{ballot, coeffs, hashSeed, inputs, nil} } func (p *Poly) setupCache() { @@ -72,10 +72,10 @@ func (p *Poly) setupCache() { // calculate outputs and their hashes p.cache.outputHashes = make(map[*fr.Element]outputCache) - for i, input := range p.Inputs { + for _, input := range p.Inputs { output := p.eval(input) byts := output.Bytes() - hash, err := mimc.Sum(p.HashSeeds[i], byts[:]) + hash, err := mimc.Sum(p.HashSeed, byts[:]) if err != nil { panic(err) } @@ -85,7 +85,7 @@ func (p *Poly) setupCache() { // compile R1CS { var err error - p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeeds, len(p.Inputs), len(p.Ballot)) + p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeed, len(p.Inputs), len(p.Ballot)) if err != nil { panic(err) } @@ -95,7 +95,7 @@ func (p *Poly) setupCache() { { witness := &p.cache.witness // public - witness.HashSeeds = p.HashSeeds + witness.HashSeed = p.HashSeed witness.Inputs = make([]frontend.Variable, len(p.Inputs)) witness.OutputHashes = make([]frontend.Variable, len(p.Inputs)) for i, input := range p.Inputs { diff --git a/math/zk.go b/math/zk.go index 372a713..af2ae4e 100644 --- a/math/zk.go +++ b/math/zk.go @@ -10,7 +10,7 @@ import ( ) type EvalsCircuit struct { - HashSeeds []string + HashSeed string Inputs []frontend.Variable `gnark:",public"` OutputHashes []frontend.Variable `gnark:",public"` @@ -22,9 +22,6 @@ type EvalsCircuit struct { } func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { - if len(circuit.HashSeeds) != len(circuit.Inputs) { - return errors.New("len(circuit.HashSeeds) != len(circuit.Inputs)") - } if len(circuit.Inputs) != len(circuit.OutputHashes) { return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)") } @@ -47,7 +44,7 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSyste } // hash - mimc, err := mimc.NewMiMC(circuit.HashSeeds[i], ecc.BLS12_381) + mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381) if err != nil { return err } @@ -72,9 +69,9 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSyste return nil } -func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) { +func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) { var circuit EvalsCircuit - circuit.HashSeeds = hashSeeds + circuit.HashSeed = hashSeed circuit.Inputs = make([]frontend.Variable, numVoters) circuit.OutputHashes = make([]frontend.Variable, numVoters) circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates) @@ -87,7 +84,7 @@ func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (fro } type SumCircuit struct { - HashSeeds []string + HashSeed string HashSelects []frontend.Variable `gnark:",public"` OutputHashes []frontend.Variable `gnark:",public"` Sum frontend.Variable `gnark:",public"` @@ -96,9 +93,6 @@ type SumCircuit struct { } func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { - if len(circuit.HashSeeds) != len(circuit.HashSelects) { - return errors.New("len(circuit.HashSeeds) != len(circuit.HashSelects)") - } if len(circuit.HashSelects) != len(circuit.OutputHashes) { return errors.New("len(circuit.HashSelects) != len(circuit.OutputHashes)") } @@ -108,13 +102,13 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) // prove the secret outputs hash to public hashes zero := cs.Constant(0) + + mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381) + if err != nil { + return err + } // TODO: this seems verify inefficient - for i, hashSeed := range circuit.HashSeeds { - mimc, err := mimc.NewMiMC(hashSeed, ecc.BLS12_381) - if err != nil { - return err - } - bit := circuit.HashSelects[i] + for _, bit := range circuit.HashSelects { for j, output := range circuit.Outputs { cs.AssertIsEqual( cs.Select(bit, mimc.Hash(cs, output), zero), @@ -134,9 +128,9 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) return nil } -func SumCircuitR1CS(hashSeeds []string, numVoters int) (frontend.CompiledConstraintSystem, error) { +func SumCircuitR1CS(hashSeed string, numVoters int) (frontend.CompiledConstraintSystem, error) { var circuit SumCircuit - circuit.HashSeeds = hashSeeds + circuit.HashSeed = hashSeed circuit.HashSelects = make([]frontend.Variable, numVoters) circuit.OutputHashes = make([]frontend.Variable, numVoters) circuit.Outputs = make([]frontend.Variable, numVoters) -- 2.38.4