From c46167bbb2adf26dff2694d5ed25211e3f9d098a Mon Sep 17 00:00:00 2001 From: David Florness Date: Fri, 7 May 2021 19:27:49 -0400 Subject: [PATCH] Use one witness for evals circuits This ensures the coefficients for any given voter are used consistently for every input s/he evaluates. --- cmd/tallyard/main.go | 7 ++- election/election.go | 19 +++++++ election/msg.go | 46 +++++++--------- election/voter.go | 62 ++-------------------- math/poly.go | 124 ++++++++++++++++++------------------------- math/zk.go | 91 ++++++++++++++++++++----------- 6 files changed, 161 insertions(+), 188 deletions(-) diff --git a/cmd/tallyard/main.go b/cmd/tallyard/main.go index a643fea..6614179 100644 --- a/cmd/tallyard/main.go +++ b/cmd/tallyard/main.go @@ -7,6 +7,7 @@ import ( "os" "time" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/kyoh86/xdg" log "github.com/sirupsen/logrus" "maunium.net/go/mautrix" @@ -115,7 +116,11 @@ func main() { // user likely hit C-c return } - el.LocalVoter.Poly = math.NewRandomPoly(*ballot, len(*el.FinalJoinIDs), &el.LocalVoter.Input) + inputs := make([]*fr.Element, len(*el.FinalJoinIDs)) + for i, joinID := range *el.FinalJoinIDs { + inputs[i] = &el.Joins[joinID].Input + } + el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeeds(), inputs) el.Save() } diff --git a/election/election.go b/election/election.go index f6b8993..ce03213 100644 --- a/election/election.go +++ b/election/election.go @@ -24,6 +24,8 @@ type Election struct { // Save election to disk. Set by the containing ElectionsMap at runtime Save func() `json:"-"` + + hashSeeds *[]string `json:"-"` } func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election { @@ -53,3 +55,20 @@ func (el *Election) UnmarshalJSON(b []byte) error { } return nil } + +func (el *Election) GetHashSeeds() []string { + if el.hashSeeds != nil { + return *el.hashSeeds + } + + if el.FinalJoinIDs == nil { + panic("GatHashSeeds called before election started") + } + + hashSeeds := make([]string, len(*el.FinalJoinIDs)) + for i, joinID := range *el.FinalJoinIDs { + hashSeeds[i] = el.Joins[joinID].HashSeed + } + el.hashSeeds = &hashSeeds + return hashSeeds +} diff --git a/election/msg.go b/election/msg.go index 2a44d43..d9f39e0 100644 --- a/election/msg.go +++ b/election/msg.go @@ -101,8 +101,6 @@ type Eval struct { // public; used by everyone in summation proofs OutputHash string `json:"output_hash"` // encrypted for specific voter - POutputHash string `json:"poutput_hash"` - // encrypted for specific voter Proof string `json:"proof"` } @@ -637,26 +635,9 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { } } - // decrypt poutputHash - var poutputHash []byte - { - encryptedPOutputHash, err := base64.StdEncoding.DecodeString(ourEval.POutputHash) - if err != nil { - warnf("couldn't decode the encrypted poutput hash for us: %s", err) - return - } - var decryptNonce [24]byte - copy(decryptNonce[:], encryptedPOutputHash[:24]) - poutputHash, ok = box.Open(nil, encryptedPOutputHash[24:], &decryptNonce, &voter.PubKey, &el.LocalVoter.PrivKey) - if !ok { - warnf("couldn't decrypt poutput hash for us") - return - } - } - outputHashes := make([][]byte, len(*el.FinalJoinIDs)) - var err error for i, eval := range content.Evals { + var err error outputHashes[i], err = base64.StdEncoding.DecodeString(eval.OutputHash) if err != nil { warnf("we couldn't decode the %d th output hash", i) @@ -666,21 +647,30 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { // verify proof { - var publicCircuit math.EvalCircuit - publicCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1) + var publicCircuit math.EvalsCircuit + + // public + publicCircuit.HashSeeds = el.GetHashSeeds() + publicCircuit.Inputs = make([]frontend.Variable, len(*el.FinalJoinIDs)) + for i, joinID := range *el.FinalJoinIDs { + publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input) + } + publicCircuit.OutputHashes = make([]frontend.Variable, len(*el.FinalJoinIDs)) + for i, outputHash := range outputHashes { + publicCircuit.OutputHashes[i].Assign(outputHash) + } + + // private publicCircuit.BallotBits = make([]frontend.Variable, len(el.Candidates)*len(el.Candidates)) - publicCircuit.Input.Assign(&el.LocalVoter.Input) - publicCircuit.Output.Assign(output) - publicCircuit.PInput.Assign(&voter.Input) - publicCircuit.POutputHash.Assign(poutputHash) - publicCircuit.HashSeed = el.LocalVoter.HashSeed + publicCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1) + if el.LocalVoter.EvalVerifyingKey == nil { // should never happen because we processed all keys // events above errorf("our evals verifying key is nil") return } - err = groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit) + err := groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit) if err != nil { warnf("poly eval proof verification failed: %s", err) return diff --git a/election/voter.go b/election/voter.go index abd0717..f3eb508 100644 --- a/election/voter.go +++ b/election/voter.go @@ -8,10 +8,7 @@ import ( "fmt" "io" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" "golang.org/x/crypto/nacl/box" @@ -173,11 +170,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto evalVerifyingKey groth16.VerifyingKey ) { - var evalCircuit math.EvalCircuit - evalCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1) - evalCircuit.BallotBits = make([]frontend.Variable, len(el.Candidates)*len(el.Candidates)) - evalCircuit.HashSeed = el.LocalVoter.HashSeed - r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &evalCircuit) + r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs), len(el.Candidates)) if err != nil { return fmt.Errorf("couldn't compile eval circuit: %s", err) } @@ -206,18 +199,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto sumVerifyingKey groth16.VerifyingKey ) { - var sumCircuit math.SumCircuit - numVoters := len(*el.FinalJoinIDs) - hashSeeds := make([]string, numVoters) - for i, joinID := range *el.FinalJoinIDs { - hashSeeds[i] = el.Joins[joinID].HashSeed - } - sumCircuit.HashSeeds = hashSeeds - sumCircuit.HashSelects = make([]frontend.Variable, numVoters) - sumCircuit.OutputHashes = make([]frontend.Variable, numVoters) - sumCircuit.Outputs = make([]frontend.Variable, numVoters) - - r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &sumCircuit) + r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs)) if err != nil { return fmt.Errorf("couldn't compile sum circuit: %s", err) } @@ -287,7 +269,7 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er for i, joinID := range *el.FinalJoinIDs { voter := el.Joins[joinID] - output, poutputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey, voter.HashSeed) + output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey) // encrypt output var outputBytes [32]byte @@ -302,24 +284,6 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er encryptedOutput = box.Seal(outputNonce[:], outputBytes[:], &outputNonce, &voter.PubKey, &el.LocalVoter.PrivKey) } - // hash output - outputHash, err := mimc.Sum(voter.HashSeed, outputBytes[:]) - if err != nil { - el.RUnlock() - return fmt.Errorf("couldn't hash output: %s", err) - } - - // hash and encrypt poutput - var encryptedPOutputHash []byte - { - var poutputNonce [24]byte - if _, err := io.ReadFull(rand.Reader, poutputNonce[:]); err != nil { - el.RUnlock() - return fmt.Errorf("couldn't read random bytes for poutput nonce: %s", err) - } - encryptedPOutputHash = box.Seal(poutputNonce[:], poutputHash, &poutputNonce, &voter.PubKey, &el.LocalVoter.PrivKey) - } - // encrypt proof var encryptedProof []byte { @@ -340,7 +304,6 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er evals[i] = Eval{ Output: base64.StdEncoding.EncodeToString(encryptedOutput), OutputHash: base64.StdEncoding.EncodeToString(outputHash), - POutputHash: base64.StdEncoding.EncodeToString(encryptedPOutputHash), Proof: base64.StdEncoding.EncodeToString(encryptedProof), } @@ -377,27 +340,12 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro numVoters := len(*el.FinalJoinIDs) evalsIDs := make([]id.EventID, numVoters) - hashSeeds := make([]string, numVoters) for i, joinID := range *el.FinalJoinIDs { voter := el.Joins[joinID] evalsIDs[i] = *voter.EvalsID - hashSeeds[i] = voter.HashSeed } - var r1cs frontend.CompiledConstraintSystem - { - var circuit math.SumCircuit - circuit.HashSeeds = hashSeeds - circuit.HashSelects = make([]frontend.Variable, numVoters) - circuit.OutputHashes = make([]frontend.Variable, numVoters) - circuit.Outputs = make([]frontend.Variable, numVoters) - var err error - r1cs, err = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit) - if err != nil { - el.RUnlock() - return fmt.Errorf("couldn't compile sum circuit: %s", err) - } - } + r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs)) proofs := make([]string, numVoters) { @@ -406,7 +354,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro outputs := make([]frontend.Variable, numVoters) var witness math.SumCircuit - witness.HashSeeds = hashSeeds + witness.HashSeeds = el.GetHashSeeds() witness.HashSelects = hashSelects witness.OutputHashes = outputHashes witness.Sum.Assign(el.LocalVoter.Sum) diff --git a/math/poly.go b/math/poly.go index 33f9c02..d4063c5 100644 --- a/math/poly.go +++ b/math/poly.go @@ -3,39 +3,38 @@ package math import ( "fmt" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" ) type Poly struct { - Ballot [][]byte `json:"ballot"` - Coeffs []*fr.Element `json:"coeffs"` - PInput *fr.Element `json:"pinput"` + Ballot [][]byte `json:"ballot"` + Coeffs []*fr.Element `json:"coeffs"` + HashSeeds []string `json:"hash_seeds"` + Inputs []*fr.Element `json:"inputs"` cache *polyCache `json:"-"` } type polyCache struct { - ballotBits []*fr.Element - constant *fr.Element - outputs map[*fr.Element]outputCache - poutput *fr.Element + ballotBits []*fr.Element + constant *fr.Element + outputHashes map[*fr.Element]outputCache + r1cs frontend.CompiledConstraintSystem + witness EvalsCircuit } -type outputCache struct{ - output *fr.Element - poutputHash []byte - proof groth16.Proof +type outputCache struct { + hash []byte + output *fr.Element } // NewRandomPoly generates a random polynomial of degree numVoters-1, using // `ballot' as the constant term -func NewRandomPoly(ballot [][]byte, numVoters int, pinput *fr.Element) *Poly { - coeffs := make([]*fr.Element, numVoters - 1) +func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *Poly { + coeffs := make([]*fr.Element, len(inputs) - 1) var err error for i := range coeffs { @@ -45,7 +44,7 @@ func NewRandomPoly(ballot [][]byte, numVoters int, pinput *fr.Element) *Poly { } } - return &Poly{ballot, coeffs, pinput, nil} + return &Poly{ballot, coeffs, hashSeeds, inputs, nil} } func (p *Poly) setupCache() { @@ -67,85 +66,66 @@ func (p *Poly) setupCache() { } } } + p.cache.ballotBits = ballotBits p.cache.constant = new(fr.Element).SetBytes(ballot1DBytes) - // malloc outputs map - p.cache.outputs = make(map[*fr.Element]outputCache) - - // calculate verifier's output - poutput := new(fr.Element).Set(p.cache.constant) - term := new(fr.Element) - for i, coeff := range p.Coeffs { - term.Set(coeff) - // <= because we want there to be at least one multiplication - for j := 0; j <= i; j++ { - term.Mul(term, p.PInput) + // calculate outputs and their hashes + p.cache.outputHashes = make(map[*fr.Element]outputCache) + for i, input := range p.Inputs { + output := p.eval(input) + byts := output.Bytes() + hash, err := mimc.Sum(p.HashSeeds[i], byts[:]) + if err != nil { + panic(err) } - poutput.Add(poutput, term) - } - p.cache.poutput = poutput - -} - -func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey, hashSeed string) (*fr.Element, []byte, groth16.Proof) { - if p.cache == nil { - p.setupCache() - } else if outputCache, exists := p.cache.outputs[input]; exists { - return outputCache.output, outputCache.poutputHash, outputCache.proof + p.cache.outputHashes[input] = outputCache{hash, output} } // compile R1CS - var r1cs frontend.CompiledConstraintSystem { - var circuit EvalCircuit - circuit.Coeffs = make([]frontend.Variable, len(p.Coeffs)) - circuit.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0])) - circuit.HashSeed = hashSeed var err error - r1cs, err = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit) + p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeeds, len(p.Inputs), len(p.Ballot)) if err != nil { panic(err) } } - output := p.eval(input) - - var witness EvalCircuit - // public - witness.Input.Assign(input) - witness.Output.Assign(output) - witness.PInput.Assign(p.PInput) - var poutputHash []byte + // create witness { - poutputBytes := p.cache.poutput.Bytes() - var err error - poutputHash, err = mimc.Sum(hashSeed, poutputBytes[:]) - if err != nil { - panic(err) + witness := &p.cache.witness + // public + witness.HashSeeds = p.HashSeeds + witness.Inputs = make([]frontend.Variable, len(p.Inputs)) + witness.OutputHashes = make([]frontend.Variable, len(p.Inputs)) + for i, input := range p.Inputs { + witness.Inputs[i].Assign(input) + witness.OutputHashes[i].Assign(p.cache.outputHashes[input].hash) } - witness.POutputHash.Assign(poutputHash) - } - witness.HashSeed = hashSeed - // private - witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0])) - for i, bit := range p.cache.ballotBits { - witness.BallotBits[i].Assign(bit) + // private + witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0])) + for i, bit := range p.cache.ballotBits { + witness.BallotBits[i].Assign(bit) + } + witness.Coeffs = make([]frontend.Variable, len(p.Coeffs)) + for i, coeff := range p.Coeffs { + witness.Coeffs[i].Assign(coeff) + } + witness.Constant.Assign(p.cache.constant) } - witness.Coeffs = make([]frontend.Variable, len(p.Coeffs)) - for i, coeff := range p.Coeffs { - witness.Coeffs[i].Assign(coeff) +} + +func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey) (*fr.Element, []byte, groth16.Proof) { + if p.cache == nil { + p.setupCache() } - witness.Constant.Assign(p.cache.constant) - witness.POutput.Assign(p.cache.poutput) - proof, err := groth16.Prove(r1cs, provingKey, &witness) + proof, err := groth16.Prove(p.cache.r1cs, provingKey, &p.cache.witness) if err != nil { panic(err) } - p.cache.outputs[input] = outputCache{output, poutputHash, proof} - return output, poutputHash, proof + return p.cache.outputHashes[input].output, p.cache.outputHashes[input].hash, proof } func (p *Poly) eval(input *fr.Element) *fr.Element { diff --git a/math/zk.go b/math/zk.go index d584cbf..372a713 100644 --- a/math/zk.go +++ b/math/zk.go @@ -4,44 +4,55 @@ import ( "errors" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash/mimc" ) -type EvalCircuit struct { - Input frontend.Variable `gnark:",public"` - Output frontend.Variable `gnark:",public"` - // prover's input - PInput frontend.Variable `gnark:",public"` - POutputHash frontend.Variable `gnark:",public"` - HashSeed string +type EvalsCircuit struct { + HashSeeds []string + Inputs []frontend.Variable `gnark:",public"` + OutputHashes []frontend.Variable `gnark:",public"` // must be created with length numCandidates^2 BallotBits []frontend.Variable `gnark:",private"` // must be created with length numVoters - 1 Coeffs []frontend.Variable `gnark:",private"` Constant frontend.Variable `gnark:",private"` - // prover's hidden output for their own polynomial - POutput frontend.Variable `gnark:",private"` } -func (circuit *EvalCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { - // prove the input yields the given output - output := circuit.Constant - poutput := circuit.Constant - for i, coeff := range circuit.Coeffs { - term := coeff - vterm := coeff - // <= because we want there to be at least one multiplication - for j := 0; j <= i; j++ { - term = cs.Mul(term, circuit.Input) - vterm = cs.Mul(vterm, circuit.PInput) +func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + if len(circuit.HashSeeds) != len(circuit.Inputs) { + return errors.New("len(circuit.HashSeeds) != len(circuit.Inputs)") + } + if len(circuit.Inputs) != len(circuit.OutputHashes) { + return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)") + } + if len(circuit.OutputHashes)-1 != len(circuit.Coeffs) { + return errors.New("len(circuit.OutputHashes)-1 != len(circuit.Coeffs)") + } + + // prove the evaluation of Inputs[i] yields an output whose hash is + // OutputHashes[i] + for i, input := range circuit.Inputs { + // eval input + output := circuit.Constant + for j, coeff := range circuit.Coeffs { + term := coeff + // <= because we want there to be at least one multiplication + for k := 0; k <= j; k++ { + term = cs.Mul(term, input) + } + output = cs.Add(output, term) + } + + // hash + mimc, err := mimc.NewMiMC(circuit.HashSeeds[i], ecc.BLS12_381) + if err != nil { + return err } - output = cs.Add(output, term) - poutput = cs.Add(poutput, vterm) + cs.AssertIsEqual(mimc.Hash(cs, output), circuit.OutputHashes[i]) } - cs.AssertIsEqual(output, circuit.Output) - cs.AssertIsEqual(poutput, circuit.POutput) // prove the constant is valid (i.e. a string of bits separated by 7 // zeros) @@ -58,14 +69,21 @@ func (circuit *EvalCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem } cs.AssertIsEqual(constructedConstant, circuit.Constant) - mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381) + return nil +} + +func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) { + var circuit EvalsCircuit + circuit.HashSeeds = hashSeeds + circuit.Inputs = make([]frontend.Variable, numVoters) + circuit.OutputHashes = make([]frontend.Variable, numVoters) + circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates) + circuit.Coeffs = make([]frontend.Variable, numVoters-1) + r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit) if err != nil { - return err + return nil, err } - - cs.AssertIsEqual(mimc.Hash(cs, circuit.POutput), circuit.POutputHash) - - return nil + return r1cs, nil } type SumCircuit struct { @@ -115,3 +133,16 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) return nil } + +func SumCircuitR1CS(hashSeeds []string, numVoters int) (frontend.CompiledConstraintSystem, error) { + var circuit SumCircuit + circuit.HashSeeds = hashSeeds + circuit.HashSelects = make([]frontend.Variable, numVoters) + circuit.OutputHashes = make([]frontend.Variable, numVoters) + circuit.Outputs = make([]frontend.Variable, numVoters) + r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit) + if err != nil { + return nil, err + } + return r1cs, nil +} -- 2.38.4