~edwargix/tallyard

c46167bbb2adf26dff2694d5ed25211e3f9d098a — David Florness 2 years ago 43b26a3
Use one witness for evals circuits

This ensures the coefficients for any given voter are used consistently for
every input s/he evaluates.
6 files changed, 161 insertions(+), 188 deletions(-)

M cmd/tallyard/main.go
M election/election.go
M election/msg.go
M election/voter.go
M math/poly.go
M math/zk.go
M cmd/tallyard/main.go => cmd/tallyard/main.go +6 -1
@@ 7,6 7,7 @@ import (
	"os"
	"time"

	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
	"github.com/kyoh86/xdg"
	log "github.com/sirupsen/logrus"
	"maunium.net/go/mautrix"


@@ 115,7 116,11 @@ func main() {
			// user likely hit C-c
			return
		}
		el.LocalVoter.Poly = math.NewRandomPoly(*ballot, len(*el.FinalJoinIDs), &el.LocalVoter.Input)
		inputs := make([]*fr.Element, len(*el.FinalJoinIDs))
		for i, joinID := range *el.FinalJoinIDs {
			inputs[i] = &el.Joins[joinID].Input
		}
		el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeeds(), inputs)
		el.Save()
	}


M election/election.go => election/election.go +19 -0
@@ 24,6 24,8 @@ type Election struct {

	// Save election to disk.  Set by the containing ElectionsMap at runtime
	Save         func()                `json:"-"`

	hashSeeds    *[]string             `json:"-"`
}

func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election {


@@ 53,3 55,20 @@ func (el *Election) UnmarshalJSON(b []byte) error {
	}
	return nil
}

func (el *Election) GetHashSeeds() []string {
	if el.hashSeeds != nil {
		return *el.hashSeeds
	}

	if el.FinalJoinIDs == nil {
		panic("GatHashSeeds called before election started")
	}

	hashSeeds := make([]string, len(*el.FinalJoinIDs))
	for i, joinID := range *el.FinalJoinIDs {
		hashSeeds[i] = el.Joins[joinID].HashSeed
	}
	el.hashSeeds = &hashSeeds
	return hashSeeds
}

M election/msg.go => election/msg.go +18 -28
@@ 101,8 101,6 @@ type Eval struct {
	// public; used by everyone in summation proofs
	OutputHash  string `json:"output_hash"`
	// encrypted for specific voter
	POutputHash string `json:"poutput_hash"`
	// encrypted for specific voter
	Proof       string `json:"proof"`
}



@@ 637,26 635,9 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
		}
	}

	// decrypt poutputHash
	var poutputHash []byte
	{
		encryptedPOutputHash, err := base64.StdEncoding.DecodeString(ourEval.POutputHash)
		if err != nil {
			warnf("couldn't decode the encrypted poutput hash for us: %s", err)
			return
		}
		var decryptNonce [24]byte
		copy(decryptNonce[:], encryptedPOutputHash[:24])
		poutputHash, ok = box.Open(nil, encryptedPOutputHash[24:], &decryptNonce, &voter.PubKey, &el.LocalVoter.PrivKey)
		if !ok {
			warnf("couldn't decrypt poutput hash for us")
			return
		}
	}

	outputHashes := make([][]byte, len(*el.FinalJoinIDs))
	var err error
	for i, eval := range content.Evals {
		var err error
		outputHashes[i], err = base64.StdEncoding.DecodeString(eval.OutputHash)
		if err != nil {
			warnf("we couldn't decode the %d th output hash", i)


@@ 666,21 647,30 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {

	// verify proof
	{
		var publicCircuit math.EvalCircuit
		publicCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1)
		var publicCircuit math.EvalsCircuit

		// public
		publicCircuit.HashSeeds = el.GetHashSeeds()
		publicCircuit.Inputs = make([]frontend.Variable, len(*el.FinalJoinIDs))
		for i, joinID := range *el.FinalJoinIDs {
			publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input)
		}
		publicCircuit.OutputHashes = make([]frontend.Variable, len(*el.FinalJoinIDs))
		for i, outputHash := range outputHashes {
			publicCircuit.OutputHashes[i].Assign(outputHash)
		}

		// private
		publicCircuit.BallotBits = make([]frontend.Variable, len(el.Candidates)*len(el.Candidates))
		publicCircuit.Input.Assign(&el.LocalVoter.Input)
		publicCircuit.Output.Assign(output)
		publicCircuit.PInput.Assign(&voter.Input)
		publicCircuit.POutputHash.Assign(poutputHash)
		publicCircuit.HashSeed = el.LocalVoter.HashSeed
		publicCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1)

		if el.LocalVoter.EvalVerifyingKey == nil {
			// should never happen because we processed all keys
			// events above
			errorf("our evals verifying key is nil")
			return
		}
		err = groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit)
		err := groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit)
		if err != nil {
			warnf("poly eval proof verification failed: %s", err)
			return

M election/voter.go => election/voter.go +5 -57
@@ 8,10 8,7 @@ import (
	"fmt"
	"io"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc"
	"github.com/consensys/gnark/backend"
	"github.com/consensys/gnark/backend/groth16"
	"github.com/consensys/gnark/frontend"
	"golang.org/x/crypto/nacl/box"


@@ 173,11 170,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
		evalVerifyingKey  groth16.VerifyingKey
	)
	{
		var evalCircuit math.EvalCircuit
		evalCircuit.Coeffs = make([]frontend.Variable, len(*el.FinalJoinIDs)-1)
		evalCircuit.BallotBits = make([]frontend.Variable, len(el.Candidates)*len(el.Candidates))
		evalCircuit.HashSeed = el.LocalVoter.HashSeed
		r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &evalCircuit)
		r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs), len(el.Candidates))
		if err != nil {
			return fmt.Errorf("couldn't compile eval circuit: %s", err)
		}


@@ 206,18 199,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
		sumVerifyingKey  groth16.VerifyingKey
	)
	{
		var sumCircuit math.SumCircuit
		numVoters := len(*el.FinalJoinIDs)
		hashSeeds := make([]string, numVoters)
		for i, joinID := range *el.FinalJoinIDs {
			hashSeeds[i] = el.Joins[joinID].HashSeed
		}
		sumCircuit.HashSeeds = hashSeeds
		sumCircuit.HashSelects = make([]frontend.Variable, numVoters)
		sumCircuit.OutputHashes = make([]frontend.Variable, numVoters)
		sumCircuit.Outputs = make([]frontend.Variable, numVoters)

		r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &sumCircuit)
		r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))
		if err != nil {
			return fmt.Errorf("couldn't compile sum circuit: %s", err)
		}


@@ 287,7 269,7 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
	for i, joinID := range *el.FinalJoinIDs {
		voter := el.Joins[joinID]

		output, poutputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey, voter.HashSeed)
		output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey)

		// encrypt output
		var outputBytes [32]byte


@@ 302,24 284,6 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
			encryptedOutput = box.Seal(outputNonce[:], outputBytes[:], &outputNonce, &voter.PubKey, &el.LocalVoter.PrivKey)
		}

		// hash output
		outputHash, err := mimc.Sum(voter.HashSeed, outputBytes[:])
		if err != nil {
			el.RUnlock()
			return fmt.Errorf("couldn't hash output: %s", err)
		}

		// hash and encrypt poutput
		var encryptedPOutputHash []byte
		{
			var poutputNonce [24]byte
			if _, err := io.ReadFull(rand.Reader, poutputNonce[:]); err != nil {
				el.RUnlock()
				return fmt.Errorf("couldn't read random bytes for poutput nonce: %s", err)
			}
			encryptedPOutputHash = box.Seal(poutputNonce[:], poutputHash, &poutputNonce, &voter.PubKey, &el.LocalVoter.PrivKey)
		}

		// encrypt proof
		var encryptedProof []byte
		{


@@ 340,7 304,6 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
		evals[i] = Eval{
			Output:      base64.StdEncoding.EncodeToString(encryptedOutput),
			OutputHash:  base64.StdEncoding.EncodeToString(outputHash),
			POutputHash: base64.StdEncoding.EncodeToString(encryptedPOutputHash),
			Proof:       base64.StdEncoding.EncodeToString(encryptedProof),
		}



@@ 377,27 340,12 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
	numVoters := len(*el.FinalJoinIDs)

	evalsIDs := make([]id.EventID, numVoters)
	hashSeeds := make([]string, numVoters)
	for i, joinID := range *el.FinalJoinIDs {
		voter := el.Joins[joinID]
		evalsIDs[i] = *voter.EvalsID
		hashSeeds[i] = voter.HashSeed
	}

	var r1cs frontend.CompiledConstraintSystem
	{
		var circuit math.SumCircuit
		circuit.HashSeeds = hashSeeds
		circuit.HashSelects = make([]frontend.Variable, numVoters)
		circuit.OutputHashes = make([]frontend.Variable, numVoters)
		circuit.Outputs = make([]frontend.Variable, numVoters)
		var err error
		r1cs, err = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
		if err != nil {
			el.RUnlock()
			return fmt.Errorf("couldn't compile sum circuit: %s", err)
		}
	}
	r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))

	proofs := make([]string, numVoters)
	{


@@ 406,7 354,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
		outputs := make([]frontend.Variable, numVoters)

		var witness math.SumCircuit
		witness.HashSeeds = hashSeeds
		witness.HashSeeds = el.GetHashSeeds()
		witness.HashSelects = hashSelects
		witness.OutputHashes = outputHashes
		witness.Sum.Assign(el.LocalVoter.Sum)

M math/poly.go => math/poly.go +52 -72
@@ 3,39 3,38 @@ package math
import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc"
	"github.com/consensys/gnark/backend"
	"github.com/consensys/gnark/backend/groth16"
	"github.com/consensys/gnark/frontend"
)

type Poly struct {
	Ballot [][]byte      `json:"ballot"`
	Coeffs []*fr.Element `json:"coeffs"`
	PInput *fr.Element   `json:"pinput"`
	Ballot    [][]byte      `json:"ballot"`
	Coeffs    []*fr.Element `json:"coeffs"`
	HashSeeds []string      `json:"hash_seeds"`
	Inputs    []*fr.Element `json:"inputs"`

	cache *polyCache `json:"-"`
}

type polyCache struct {
	ballotBits []*fr.Element
	constant   *fr.Element
	outputs    map[*fr.Element]outputCache
	poutput    *fr.Element
	ballotBits   []*fr.Element
	constant     *fr.Element
	outputHashes map[*fr.Element]outputCache
	r1cs         frontend.CompiledConstraintSystem
	witness      EvalsCircuit
}

type outputCache struct{
	output      *fr.Element
	poutputHash []byte
	proof       groth16.Proof
type outputCache struct {
	hash   []byte
	output *fr.Element
}

// NewRandomPoly generates a random polynomial of degree numVoters-1, using
// `ballot' as the constant term
func NewRandomPoly(ballot [][]byte, numVoters int, pinput *fr.Element) *Poly {
	coeffs := make([]*fr.Element, numVoters - 1)
func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *Poly {
	coeffs := make([]*fr.Element, len(inputs) - 1)

	var err error
	for i := range coeffs {


@@ 45,7 44,7 @@ func NewRandomPoly(ballot [][]byte, numVoters int, pinput *fr.Element) *Poly {
		}
	}

	return &Poly{ballot, coeffs, pinput, nil}
	return &Poly{ballot, coeffs, hashSeeds, inputs, nil}
}

func (p *Poly) setupCache() {


@@ 67,85 66,66 @@ func (p *Poly) setupCache() {
			}
		}
	}

	p.cache.ballotBits = ballotBits
	p.cache.constant = new(fr.Element).SetBytes(ballot1DBytes)

	// malloc outputs map
	p.cache.outputs = make(map[*fr.Element]outputCache)

	// calculate verifier's output
	poutput := new(fr.Element).Set(p.cache.constant)
	term := new(fr.Element)
	for i, coeff := range p.Coeffs {
		term.Set(coeff)
		// <= because we want there to be at least one multiplication
		for j := 0; j <= i; j++ {
			term.Mul(term, p.PInput)
	// calculate outputs and their hashes
	p.cache.outputHashes = make(map[*fr.Element]outputCache)
	for i, input := range p.Inputs {
		output := p.eval(input)
		byts := output.Bytes()
		hash, err := mimc.Sum(p.HashSeeds[i], byts[:])
		if err != nil {
			panic(err)
		}
		poutput.Add(poutput, term)
	}
	p.cache.poutput = poutput

}

func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey, hashSeed string) (*fr.Element, []byte, groth16.Proof) {
	if p.cache == nil {
		p.setupCache()
	} else if outputCache, exists := p.cache.outputs[input]; exists {
		return outputCache.output, outputCache.poutputHash, outputCache.proof
		p.cache.outputHashes[input] = outputCache{hash, output}
	}

	// compile R1CS
	var r1cs frontend.CompiledConstraintSystem
	{
		var circuit EvalCircuit
		circuit.Coeffs = make([]frontend.Variable, len(p.Coeffs))
		circuit.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
		circuit.HashSeed = hashSeed
		var err error
		r1cs, err = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
		p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeeds, len(p.Inputs), len(p.Ballot))
		if err != nil {
			panic(err)
		}
	}

	output := p.eval(input)

	var witness EvalCircuit
	// public
	witness.Input.Assign(input)
	witness.Output.Assign(output)
	witness.PInput.Assign(p.PInput)
	var poutputHash []byte
	// create witness
	{
		poutputBytes := p.cache.poutput.Bytes()
		var err error
		poutputHash, err = mimc.Sum(hashSeed, poutputBytes[:])
		if err != nil {
			panic(err)
		witness := &p.cache.witness
		// public
		witness.HashSeeds = p.HashSeeds
		witness.Inputs = make([]frontend.Variable, len(p.Inputs))
		witness.OutputHashes = make([]frontend.Variable, len(p.Inputs))
		for i, input := range p.Inputs {
			witness.Inputs[i].Assign(input)
			witness.OutputHashes[i].Assign(p.cache.outputHashes[input].hash)
		}
		witness.POutputHash.Assign(poutputHash)
	}
	witness.HashSeed = hashSeed
	// private
	witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
	for i, bit := range p.cache.ballotBits {
		witness.BallotBits[i].Assign(bit)
		// private
		witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
		for i, bit := range p.cache.ballotBits {
			witness.BallotBits[i].Assign(bit)
		}
		witness.Coeffs = make([]frontend.Variable, len(p.Coeffs))
		for i, coeff := range p.Coeffs {
			witness.Coeffs[i].Assign(coeff)
		}
		witness.Constant.Assign(p.cache.constant)
	}
	witness.Coeffs = make([]frontend.Variable, len(p.Coeffs))
	for i, coeff := range p.Coeffs {
		witness.Coeffs[i].Assign(coeff)
}

func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey) (*fr.Element, []byte, groth16.Proof) {
	if p.cache == nil {
		p.setupCache()
	}
	witness.Constant.Assign(p.cache.constant)
	witness.POutput.Assign(p.cache.poutput)

	proof, err := groth16.Prove(r1cs, provingKey, &witness)
	proof, err := groth16.Prove(p.cache.r1cs, provingKey, &p.cache.witness)
	if err != nil {
		panic(err)
	}

	p.cache.outputs[input] = outputCache{output, poutputHash, proof}
	return output, poutputHash, proof
	return p.cache.outputHashes[input].output, p.cache.outputHashes[input].hash, proof
}

func (p *Poly) eval(input *fr.Element) *fr.Element {

M math/zk.go => math/zk.go +61 -30
@@ 4,44 4,55 @@ import (
	"errors"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark/backend"
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/hash/mimc"
)

type EvalCircuit struct {
	Input       frontend.Variable `gnark:",public"`
	Output      frontend.Variable `gnark:",public"`
	// prover's input
	PInput      frontend.Variable `gnark:",public"`
	POutputHash frontend.Variable `gnark:",public"`
	HashSeed    string
type EvalsCircuit struct {
	HashSeeds    []string
	Inputs       []frontend.Variable `gnark:",public"`
	OutputHashes []frontend.Variable `gnark:",public"`

	// must be created with length numCandidates^2
	BallotBits []frontend.Variable `gnark:",private"`
	// must be created with length numVoters - 1
	Coeffs     []frontend.Variable `gnark:",private"`
	Constant   frontend.Variable   `gnark:",private"`
	// prover's hidden output for their own polynomial
	POutput    frontend.Variable   `gnark:",private"`
}

func (circuit *EvalCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
	// prove the input yields the given output
	output := circuit.Constant
	poutput := circuit.Constant
	for i, coeff := range circuit.Coeffs {
		term := coeff
		vterm := coeff
		// <= because we want there to be at least one multiplication
		for j := 0; j <= i; j++ {
			term = cs.Mul(term, circuit.Input)
			vterm = cs.Mul(vterm, circuit.PInput)
func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
	if len(circuit.HashSeeds) != len(circuit.Inputs) {
		return errors.New("len(circuit.HashSeeds) != len(circuit.Inputs)")
	}
	if len(circuit.Inputs) != len(circuit.OutputHashes) {
		return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)")
	}
	if len(circuit.OutputHashes)-1 != len(circuit.Coeffs) {
		return errors.New("len(circuit.OutputHashes)-1 != len(circuit.Coeffs)")
	}

	// prove the evaluation of Inputs[i] yields an output whose hash is
	// OutputHashes[i]
	for i, input := range circuit.Inputs {
		// eval input
		output := circuit.Constant
		for j, coeff := range circuit.Coeffs {
			term := coeff
			// <= because we want there to be at least one multiplication
			for k := 0; k <= j; k++ {
				term = cs.Mul(term, input)
			}
			output = cs.Add(output, term)
		}

		// hash
		mimc, err := mimc.NewMiMC(circuit.HashSeeds[i], ecc.BLS12_381)
		if err != nil {
			return err
		}
		output = cs.Add(output, term)
		poutput = cs.Add(poutput, vterm)
		cs.AssertIsEqual(mimc.Hash(cs, output), circuit.OutputHashes[i])
	}
	cs.AssertIsEqual(output, circuit.Output)
	cs.AssertIsEqual(poutput, circuit.POutput)

	// prove the constant is valid (i.e. a string of bits separated by 7
	// zeros)


@@ 58,14 69,21 @@ func (circuit *EvalCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem
	}
	cs.AssertIsEqual(constructedConstant, circuit.Constant)

	mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381)
	return nil
}

func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
	var circuit EvalsCircuit
	circuit.HashSeeds = hashSeeds
	circuit.Inputs = make([]frontend.Variable, numVoters)
	circuit.OutputHashes = make([]frontend.Variable, numVoters)
	circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates)
	circuit.Coeffs = make([]frontend.Variable, numVoters-1)
	r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
	if err != nil {
		return err
		return nil, err
	}

	cs.AssertIsEqual(mimc.Hash(cs, circuit.POutput), circuit.POutputHash)

	return nil
	return r1cs, nil
}

type SumCircuit struct {


@@ 115,3 133,16 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem)

	return nil
}

func SumCircuitR1CS(hashSeeds []string, numVoters int) (frontend.CompiledConstraintSystem, error) {
	var circuit SumCircuit
	circuit.HashSeeds = hashSeeds
	circuit.HashSelects = make([]frontend.Variable, numVoters)
	circuit.OutputHashes = make([]frontend.Variable, numVoters)
	circuit.Outputs = make([]frontend.Variable, numVoters)
	r1cs, err := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &circuit)
	if err != nil {
		return nil, err
	}
	return r1cs, nil
}