M cmd/tallyard/main.go => cmd/tallyard/main.go +6 -1
@@ 7,6 7,7 @@ import (
"os"
"time"
+ "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/kyoh86/xdg"
log "github.com/sirupsen/logrus"
"maunium.net/go/mautrix"
@@ 115,7 116,11 @@ func main() {
// user likely hit C-c
return
}
- el.LocalVoter.Poly = math.NewRandomPoly(*ballot, len(*el.FinalJoinIDs), &el.LocalVoter.Input)
+ inputs := make([]*fr.Element, len(*el.FinalJoinIDs))
+ for i, joinID := range *el.FinalJoinIDs {
+ inputs[i] = &el.Joins[joinID].Input
+ }
+ el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeeds(), inputs)
el.Save()
}
M election/election.go => election/election.go +19 -0
@@ 24,6 24,8 @@ type Election struct {
// Save election to disk. Set by the containing ElectionsMap at runtime
Save func() `json:"-"`
+
+ hashSeeds *[]string `json:"-"`
}
func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election {
@@ 53,3 55,20 @@ func (el *Election) UnmarshalJSON(b []byte) error {
}
return nil
}
+
+func (el *Election) GetHashSeeds() []string {
+ if el.hashSeeds != nil {
+ return *el.hashSeeds
+ }
+
+ if el.FinalJoinIDs == nil {
+ panic("GatHashSeeds called before election started")
+ }
+
+ hashSeeds := make([]string, len(*el.FinalJoinIDs))
+ for i, joinID := range *el.FinalJoinIDs {
+ hashSeeds[i] = el.Joins[joinID].HashSeed
+ }
+ el.hashSeeds = &hashSeeds
+ return hashSeeds
+}
M election/msg.go => election/msg.go +18 -28
@@ 101,8 101,6 @@ type Eval struct {
// public; used by everyone in summation proofs
OutputHash string `json:"output_hash"`
// encrypted for specific voter
- POutputHash string `json:"poutput_hash"`
- // encrypted for specific voter
Proof string `json:"proof"`
}
@@ 637,26 635,9 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
}
}
- // decrypt poutputHash
- var poutputHash []byte
- {
- encryptedPOutputHash, err := base64.StdEncoding.DecodeString(ourEval.POutputHash)
- if err != nil {
- warnf("couldn't decode the encrypted poutput hash for us: %s", err)
- return
- }
- var decryptNonce [24]byte
- copy(decryptNonce[:], encryptedPOutputHash[:24])
- poutputHash, ok = box.Open(nil, encryptedPOutputHash[24:], &decryptNonce, &voter.PubKey, &el.LocalVoter.PrivKey)
- if !ok {
- warnf("couldn't decrypt poutput hash for us")
- return
- }
- }
-
outputHashes := make([][]byte, len(*el.FinalJoinIDs))
- var err error
for i, eval := range content.Evals {
+ var err error
outputHashes[i], err = base64.StdEncoding.DecodeString(eval.OutputHash)
if err != nil {
warnf("we couldn't decode the %d th output hash", i)
@@ 666,21 647,30 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
// verify proof
{
- var publicCircuit math.EvalCircuit
- publicCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1)
+ var publicCircuit math.EvalsCircuit
+
+ // public
+ publicCircuit.HashSeeds = el.GetHashSeeds()
+ publicCircuit.Inputs = make([]frontend.Variable, len(*el.FinalJoinIDs))
+ for i, joinID := range *el.FinalJoinIDs {
+ publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input)
+ }
+ publicCircuit.OutputHashes = make([]frontend.Variable, len(*el.FinalJoinIDs))
+ for i, outputHash := range outputHashes {
+ publicCircuit.OutputHashes[i].Assign(outputHash)
+ }
+
+ // private
publicCircuit.BallotBits = make([]frontend.Variable, len(el.Candidates)*len(el.Candidates))
- publicCircuit.Input.Assign(&el.LocalVoter.Input)
- publicCircuit.Output.Assign(output)
- publicCircuit.PInput.Assign(&voter.Input)
- publicCircuit.POutputHash.Assign(poutputHash)
- publicCircuit.HashSeed = el.LocalVoter.HashSeed
+ publicCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1)
+
if el.LocalVoter.EvalVerifyingKey == nil {
// should never happen because we processed all keys
// events above
errorf("our evals verifying key is nil")
return
}
- err = groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit)
+ err := groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit)
if err != nil {
warnf("poly eval proof verification failed: %s", err)
return
M election/voter.go => election/voter.go +5 -57
@@ 8,10 8,7 @@ import (
"fmt"
"io"
- "github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
- "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc"
- "github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/frontend"
"golang.org/x/crypto/nacl/box"
@@ 173,11 170,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
evalVerifyingKey groth16.VerifyingKey
)
{
- var evalCircuit math.EvalCircuit
- evalCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1)
- evalCircuit.BallotBits = make([]frontend.Variable, len(el.Candidates)*len(el.Candidates))
- evalCircuit.HashSeed = el.LocalVoter.HashSeed
- r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &evalCircuit)
+ r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs), len(el.Candidates))
if err != nil {
return fmt.Errorf("couldn't compile eval circuit: %s", err)
}
@@ 206,18 199,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
sumVerifyingKey groth16.VerifyingKey
)
{
- var sumCircuit math.SumCircuit
- numVoters := len(*el.FinalJoinIDs)
- hashSeeds := make([]string, numVoters)
- for i, joinID := range *el.FinalJoinIDs {
- hashSeeds[i] = el.Joins[joinID].HashSeed
- }
- sumCircuit.HashSeeds = hashSeeds
- sumCircuit.HashSelects = make([]frontend.Variable, numVoters)
- sumCircuit.OutputHashes = make([]frontend.Variable, numVoters)
- sumCircuit.Outputs = make([]frontend.Variable, numVoters)
-
- r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &sumCircuit)
+ r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))
if err != nil {
return fmt.Errorf("couldn't compile sum circuit: %s", err)
}
@@ 287,7 269,7 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
for i, joinID := range *el.FinalJoinIDs {
voter := el.Joins[joinID]
- output, poutputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey, voter.HashSeed)
+ output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey)
// encrypt output
var outputBytes [32]byte
@@ 302,24 284,6 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
encryptedOutput = box.Seal(outputNonce[:], outputBytes[:], &outputNonce, &voter.PubKey, &el.LocalVoter.PrivKey)
}
- // hash output
- outputHash, err := mimc.Sum(voter.HashSeed, outputBytes[:])
- if err != nil {
- el.RUnlock()
- return fmt.Errorf("couldn't hash output: %s", err)
- }
-
- // hash and encrypt poutput
- var encryptedPOutputHash []byte
- {
- var poutputNonce [24]byte
- if _, err := io.ReadFull(rand.Reader, poutputNonce[:]); err != nil {
- el.RUnlock()
- return fmt.Errorf("couldn't read random bytes for poutput nonce: %s", err)
- }
- encryptedPOutputHash = box.Seal(poutputNonce[:], poutputHash, &poutputNonce, &voter.PubKey, &el.LocalVoter.PrivKey)
- }
-
// encrypt proof
var encryptedProof []byte
{
@@ 340,7 304,6 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
evals[i] = Eval{
Output: base64.StdEncoding.EncodeToString(encryptedOutput),
OutputHash: base64.StdEncoding.EncodeToString(outputHash),
- POutputHash: base64.StdEncoding.EncodeToString(encryptedPOutputHash),
Proof: base64.StdEncoding.EncodeToString(encryptedProof),
}
@@ 377,27 340,12 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
numVoters := len(*el.FinalJoinIDs)
evalsIDs := make([]id.EventID, numVoters)
- hashSeeds := make([]string, numVoters)
for i, joinID := range *el.FinalJoinIDs {
voter := el.Joins[joinID]
evalsIDs[i] = *voter.EvalsID
- hashSeeds[i] = voter.HashSeed
}
- var r1cs frontend.CompiledConstraintSystem
- {
- var circuit math.SumCircuit
- circuit.HashSeeds = hashSeeds
- circuit.HashSelects = make([]frontend.Variable, numVoters)
- circuit.OutputHashes = make([]frontend.Variable, numVoters)
- circuit.Outputs = make([]frontend.Variable, numVoters)
- var err error
- r1cs, err = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
- if err != nil {
- el.RUnlock()
- return fmt.Errorf("couldn't compile sum circuit: %s", err)
- }
- }
+ r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))
proofs := make([]string, numVoters)
{
@@ 406,7 354,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
outputs := make([]frontend.Variable, numVoters)
var witness math.SumCircuit
- witness.HashSeeds = hashSeeds
+ witness.HashSeeds = el.GetHashSeeds()
witness.HashSelects = hashSelects
witness.OutputHashes = outputHashes
witness.Sum.Assign(el.LocalVoter.Sum)
M math/poly.go => math/poly.go +52 -72
@@ 3,39 3,38 @@ package math
import (
"fmt"
- "github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc"
- "github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/frontend"
)
type Poly struct {
- Ballot [][]byte `json:"ballot"`
- Coeffs []*fr.Element `json:"coeffs"`
- PInput *fr.Element `json:"pinput"`
+ Ballot [][]byte `json:"ballot"`
+ Coeffs []*fr.Element `json:"coeffs"`
+ HashSeeds []string `json:"hash_seeds"`
+ Inputs []*fr.Element `json:"inputs"`
cache *polyCache `json:"-"`
}
type polyCache struct {
- ballotBits []*fr.Element
- constant *fr.Element
- outputs map[*fr.Element]outputCache
- poutput *fr.Element
+ ballotBits []*fr.Element
+ constant *fr.Element
+ outputHashes map[*fr.Element]outputCache
+ r1cs frontend.CompiledConstraintSystem
+ witness EvalsCircuit
}
-type outputCache struct{
- output *fr.Element
- poutputHash []byte
- proof groth16.Proof
+type outputCache struct {
+ hash []byte
+ output *fr.Element
}
// NewRandomPoly generates a random polynomial of degree numVoters-1, using
// `ballot' as the constant term
-func NewRandomPoly(ballot [][]byte, numVoters int, pinput *fr.Element) *Poly {
- coeffs := make([]*fr.Element, numVoters - 1)
+func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *Poly {
+ coeffs := make([]*fr.Element, len(inputs) - 1)
var err error
for i := range coeffs {
@@ 45,7 44,7 @@ func NewRandomPoly(ballot [][]byte, numVoters int, pinput *fr.Element) *Poly {
}
}
- return &Poly{ballot, coeffs, pinput, nil}
+ return &Poly{ballot, coeffs, hashSeeds, inputs, nil}
}
func (p *Poly) setupCache() {
@@ 67,85 66,66 @@ func (p *Poly) setupCache() {
}
}
}
+
p.cache.ballotBits = ballotBits
p.cache.constant = new(fr.Element).SetBytes(ballot1DBytes)
- // malloc outputs map
- p.cache.outputs = make(map[*fr.Element]outputCache)
-
- // calculate verifier's output
- poutput := new(fr.Element).Set(p.cache.constant)
- term := new(fr.Element)
- for i, coeff := range p.Coeffs {
- term.Set(coeff)
- // <= because we want there to be at least one multiplication
- for j := 0; j <= i; j++ {
- term.Mul(term, p.PInput)
+ // calculate outputs and their hashes
+ p.cache.outputHashes = make(map[*fr.Element]outputCache)
+ for i, input := range p.Inputs {
+ output := p.eval(input)
+ byts := output.Bytes()
+ hash, err := mimc.Sum(p.HashSeeds[i], byts[:])
+ if err != nil {
+ panic(err)
}
- poutput.Add(poutput, term)
- }
- p.cache.poutput = poutput
-
-}
-
-func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey, hashSeed string) (*fr.Element, []byte, groth16.Proof) {
- if p.cache == nil {
- p.setupCache()
- } else if outputCache, exists := p.cache.outputs[input]; exists {
- return outputCache.output, outputCache.poutputHash, outputCache.proof
+ p.cache.outputHashes[input] = outputCache{hash, output}
}
// compile R1CS
- var r1cs frontend.CompiledConstraintSystem
{
- var circuit EvalCircuit
- circuit.Coeffs = make([]frontend.Variable, len(p.Coeffs))
- circuit.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
- circuit.HashSeed = hashSeed
var err error
- r1cs, err = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
+ p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeeds, len(p.Inputs), len(p.Ballot))
if err != nil {
panic(err)
}
}
- output := p.eval(input)
-
- var witness EvalCircuit
- // public
- witness.Input.Assign(input)
- witness.Output.Assign(output)
- witness.PInput.Assign(p.PInput)
- var poutputHash []byte
+ // create witness
{
- poutputBytes := p.cache.poutput.Bytes()
- var err error
- poutputHash, err = mimc.Sum(hashSeed, poutputBytes[:])
- if err != nil {
- panic(err)
+ witness := &p.cache.witness
+ // public
+ witness.HashSeeds = p.HashSeeds
+ witness.Inputs = make([]frontend.Variable, len(p.Inputs))
+ witness.OutputHashes = make([]frontend.Variable, len(p.Inputs))
+ for i, input := range p.Inputs {
+ witness.Inputs[i].Assign(input)
+ witness.OutputHashes[i].Assign(p.cache.outputHashes[input].hash)
}
- witness.POutputHash.Assign(poutputHash)
- }
- witness.HashSeed = hashSeed
- // private
- witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
- for i, bit := range p.cache.ballotBits {
- witness.BallotBits[i].Assign(bit)
+ // private
+ witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
+ for i, bit := range p.cache.ballotBits {
+ witness.BallotBits[i].Assign(bit)
+ }
+ witness.Coeffs = make([]frontend.Variable, len(p.Coeffs))
+ for i, coeff := range p.Coeffs {
+ witness.Coeffs[i].Assign(coeff)
+ }
+ witness.Constant.Assign(p.cache.constant)
}
- witness.Coeffs = make([]frontend.Variable, len(p.Coeffs))
- for i, coeff := range p.Coeffs {
- witness.Coeffs[i].Assign(coeff)
+}
+
+func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey) (*fr.Element, []byte, groth16.Proof) {
+ if p.cache == nil {
+ p.setupCache()
}
- witness.Constant.Assign(p.cache.constant)
- witness.POutput.Assign(p.cache.poutput)
- proof, err := groth16.Prove(r1cs, provingKey, &witness)
+ proof, err := groth16.Prove(p.cache.r1cs, provingKey, &p.cache.witness)
if err != nil {
panic(err)
}
- p.cache.outputs[input] = outputCache{output, poutputHash, proof}
- return output, poutputHash, proof
+ return p.cache.outputHashes[input].output, p.cache.outputHashes[input].hash, proof
}
func (p *Poly) eval(input *fr.Element) *fr.Element {
M math/zk.go => math/zk.go +61 -30
@@ 4,44 4,55 @@ import (
"errors"
"github.com/consensys/gnark-crypto/ecc"
+ "github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash/mimc"
)
-type EvalCircuit struct {
- Input frontend.Variable `gnark:",public"`
- Output frontend.Variable `gnark:",public"`
- // prover's input
- PInput frontend.Variable `gnark:",public"`
- POutputHash frontend.Variable `gnark:",public"`
- HashSeed string
+type EvalsCircuit struct {
+ HashSeeds []string
+ Inputs []frontend.Variable `gnark:",public"`
+ OutputHashes []frontend.Variable `gnark:",public"`
// must be created with length numCandidates^2
BallotBits []frontend.Variable `gnark:",private"`
// must be created with length numVoters - 1
Coeffs []frontend.Variable `gnark:",private"`
Constant frontend.Variable `gnark:",private"`
- // prover's hidden output for their own polynomial
- POutput frontend.Variable `gnark:",private"`
}
-func (circuit *EvalCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
- // prove the input yields the given output
- output := circuit.Constant
- poutput := circuit.Constant
- for i, coeff := range circuit.Coeffs {
- term := coeff
- vterm := coeff
- // <= because we want there to be at least one multiplication
- for j := 0; j <= i; j++ {
- term = cs.Mul(term, circuit.Input)
- vterm = cs.Mul(vterm, circuit.PInput)
+func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
+ if len(circuit.HashSeeds) != len(circuit.Inputs) {
+ return errors.New("len(circuit.HashSeeds) != len(circuit.Inputs)")
+ }
+ if len(circuit.Inputs) != len(circuit.OutputHashes) {
+ return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)")
+ }
+ if len(circuit.OutputHashes)-1 != len(circuit.Coeffs) {
+ return errors.New("len(circuit.OutputHashes)-1 != len(circuit.Coeffs)")
+ }
+
+ // prove the evaluation of Inputs[i] yields an output whose hash is
+ // OutputHashes[i]
+ for i, input := range circuit.Inputs {
+ // eval input
+ output := circuit.Constant
+ for j, coeff := range circuit.Coeffs {
+ term := coeff
+ // <= because we want there to be at least one multiplication
+ for k := 0; k <= j; k++ {
+ term = cs.Mul(term, input)
+ }
+ output = cs.Add(output, term)
+ }
+
+ // hash
+ mimc, err := mimc.NewMiMC(circuit.HashSeeds[i], ecc.BLS12_381)
+ if err != nil {
+ return err
}
- output = cs.Add(output, term)
- poutput = cs.Add(poutput, vterm)
+ cs.AssertIsEqual(mimc.Hash(cs, output), circuit.OutputHashes[i])
}
- cs.AssertIsEqual(output, circuit.Output)
- cs.AssertIsEqual(poutput, circuit.POutput)
// prove the constant is valid (i.e. a string of bits separated by 7
// zeros)
@@ 58,14 69,21 @@ func (circuit *EvalCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem
}
cs.AssertIsEqual(constructedConstant, circuit.Constant)
- mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381)
+ return nil
+}
+
+func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
+ var circuit EvalsCircuit
+ circuit.HashSeeds = hashSeeds
+ circuit.Inputs = make([]frontend.Variable, numVoters)
+ circuit.OutputHashes = make([]frontend.Variable, numVoters)
+ circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates)
+ circuit.Coeffs = make([]frontend.Variable, numVoters-1)
+ r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
if err != nil {
- return err
+ return nil, err
}
-
- cs.AssertIsEqual(mimc.Hash(cs, circuit.POutput), circuit.POutputHash)
-
- return nil
+ return r1cs, nil
}
type SumCircuit struct {
@@ 115,3 133,16 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem)
return nil
}
+
+func SumCircuitR1CS(hashSeeds []string, numVoters int) (frontend.CompiledConstraintSystem, error) {
+ var circuit SumCircuit
+ circuit.HashSeeds = hashSeeds
+ circuit.HashSelects = make([]frontend.Variable, numVoters)
+ circuit.OutputHashes = make([]frontend.Variable, numVoters)
+ circuit.Outputs = make([]frontend.Variable, numVoters)
+ r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
+ if err != nil {
+ return nil, err
+ }
+ return r1cs, nil
+}