~edwargix/tallyard

d6670588a8acc712ba0aa2637464fc40af73dec2 — David Florness 3 years ago 48cb6f2
zk: derive one hash seed by combining each "part" sent by voters

This slightly simplifies the circuits and protects against voters picking
advantageous seeds [0].

[0]: Hopefully.  This seems evident to me, but I'm not a cryptographer :)
M cmd/tallyard/main.go => cmd/tallyard/main.go +1 -1
@@ 128,7 128,7 @@ func main() {
		for i, joinID := range *el.FinalJoinIDs {
			inputs[i] = &el.Joins[joinID].Input
		}
		el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeeds(), inputs)
		el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeed(), inputs)
		el.Save()
	}


M election/election.go => election/election.go +10 -8
@@ 7,6 7,7 @@ import (

	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
	"tallyard.xyz/math"
)

type Election struct {


@@ 25,7 26,7 @@ type Election struct {
	// Save election to disk.  Set by the containing ElectionsMap at runtime
	Save         func()                `json:"-"`

	hashSeeds    *[]string             `json:"-"`
	hashSeed     *string               `json:"-"`
}

func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election {


@@ 56,19 57,20 @@ func (el *Election) UnmarshalJSON(b []byte) error {
	return nil
}

func (el *Election) GetHashSeeds() []string {
	if el.hashSeeds != nil {
		return *el.hashSeeds
func (el *Election) GetHashSeed() string {
	if el.hashSeed != nil {
		return *el.hashSeed
	}

	if el.FinalJoinIDs == nil {
		panic("GatHashSeeds called before election started")
	}

	hashSeeds := make([]string, len(*el.FinalJoinIDs))
	seedParts := make([][]byte, len(*el.FinalJoinIDs))
	for i, joinID := range *el.FinalJoinIDs {
		hashSeeds[i] = el.Joins[joinID].HashSeed
		seedParts[i] = el.Joins[joinID].SeedPart
	}
	el.hashSeeds = &hashSeeds
	return hashSeeds
	hashSeed := math.HashSeedFromSeedParts(seedParts)
	el.hashSeed = &hashSeed
	return hashSeed
}

M election/marshal_test.go => election/marshal_test.go +15 -12
@@ 40,15 40,16 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	pubKey, privKey, err := box.GenerateKey(rand.Reader)

	// HashSeed(s)
	hashSeeds := make([]string, numVoters)
	for i := range hashSeeds {
		var hashSeed [32]byte
		if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil {
	seedParts := make([][]byte, numVoters)
	for i := range seedParts {
		var seedPart [32]byte
		if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil {
			t.Fatal(err)
		}
		hashSeeds[i] = base64.StdEncoding.EncodeToString(hashSeed[:])
		seedParts[i] = seedPart[:]
	}
	hashSeed := hashSeeds[joinIDIndex]

	localHashSeed := seedParts[joinIDIndex]

	// JoinEvt
	joinEvt := &event.Event{


@@ 63,14 64,16 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
				CreateID: "createID",
				Input:    base64.StdEncoding.EncodeToString(inputBytes[:]),
				PubKey:   base64.StdEncoding.EncodeToString((*pubKey)[:]),
				HashSeed: hashSeed,
				SeedPart: base64.StdEncoding.EncodeToString(localHashSeed),
			},
		},
	}

	localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, hashSeed), privKey)
	localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, localHashSeed), privKey)
	localVoter.JoinIDIndex = &joinIDIndex

	hashSeed := math.HashSeedFromSeedParts(seedParts)

	// Output
	localVoter.Output, err = new(fr.Element).SetRandom()
	if err != nil {


@@ 109,7 112,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	const numCandidates = 5

	{
		r1cs, err := math.EvalsCircuitR1CS(hashSeeds, numVoters, numCandidates)
		r1cs, err := math.EvalsCircuitR1CS(hashSeed, numVoters, numCandidates)
		if err != nil {
			t.Fatal(err)
		}


@@ 125,7 128,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	}

	{
		r1cs, err := math.SumCircuitR1CS(hashSeeds, numVoters)
		r1cs, err := math.SumCircuitR1CS(hashSeed, numVoters)
		if err != nil {
			t.Fatal(err)
		}


@@ 148,7 151,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	}

	// Poly
	localVoter.Poly = math.NewRandomPoly(ballot, hashSeeds, inputs)
	localVoter.Poly = math.NewRandomPoly(ballot, hashSeed, inputs)

	return localVoter
}


@@ 184,7 187,7 @@ func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) {
	}

	// HashSeed
	if lv1.HashSeed != lv2.HashSeed {
	if bytes.Compare(lv1.SeedPart, lv2.SeedPart) != 0 {
		t.Errorf("hash seeds not equal")
	}


M election/msg.go => election/msg.go +12 -12
@@ 69,7 69,7 @@ type JoinElectionContent struct {
	CreateID id.EventID `json:"create_id"`
	Input    string     `json:"input"`
	PubKey   string     `json:"pub_key"`
	HashSeed string     `json:"hash_seed"`
	SeedPart string     `json:"seed_part"`
}

type StartElectionContent struct {


@@ 281,18 281,18 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success 
		return
	}

	// HashSeed
	if content.HashSeed == "" {
		warnf("the hash seed is empty")
	// SeedPart
	if content.SeedPart == "" {
		warnf("the seed part is empty")
		return
	}
	hashSeedBytes, err := base64.StdEncoding.DecodeString(content.HashSeed)
	seedPart, err := base64.StdEncoding.DecodeString(content.SeedPart)
	if err != nil {
		warnf("we couldn't decode their hash seed: %s", err)
		warnf("we couldn't decode their seed part: %s", err)
		return
	}
	if len(hashSeedBytes) < 32 {
		warnf("their hash seed is fewer than 32 bytes")
	if len(seedPart) < 32 {
		warnf("their seed part is fewer than 32 bytes")
		return
	}



@@ 330,7 330,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success 
	defer el.Save()
	defer el.Unlock()

	el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, content.HashSeed)
	el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, seedPart)

	return true
}


@@ 665,7 665,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
		var publicCircuit math.EvalsCircuit

		// public
		publicCircuit.HashSeeds = el.GetHashSeeds()
		publicCircuit.HashSeed = el.GetHashSeed()
		publicCircuit.Inputs = make([]frontend.Variable, len(*el.FinalJoinIDs))
		for i, joinID := range *el.FinalJoinIDs {
			publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input)


@@ 695,7 695,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
	// ensure our output hashes to the given hash
	{
		outputBytes := output.Bytes()
		outputHash, err := mimc.Sum(el.LocalVoter.HashSeed, outputBytes[:])
		outputHash, err := mimc.Sum(el.GetHashSeed(), outputBytes[:])
		if err != nil {
			errorf("couldn't hash output: %s", err)
			return


@@ 847,7 847,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
			}
			outputHashes[i].Assign((*evaler.OutputHashes)[voterJoinIDIndex])
		}
		publicCircuit.HashSeeds = el.GetHashSeeds()
		publicCircuit.HashSeed = el.GetHashSeed()
		publicCircuit.HashSelects = hashSelects
		publicCircuit.OutputHashes = outputHashes
		publicCircuit.Sum.Assign(sum)

M election/voter.go => election/voter.go +8 -8
@@ 23,7 23,7 @@ type Voter struct {
	Input    fr.Element  `json:"input"`
	JoinEvt  event.Event `json:"join_evt"`
	PubKey   [32]byte    `json:"pub_key"`
	HashSeed string      `json:"hash_seed"`
	SeedPart []byte      `json:"seed_part"`

	JoinIDIndex  *uint       `json:"join_id_index,omitempty"`
	Output       *fr.Element `json:"output,omitempty"`


@@ 47,12 47,12 @@ type LocalVoter struct {
	SumVerifyingKey  *MarshallableVerifyingKey `json:"sum_verifying_key,omitempty"`
}

func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed string) *Voter {
func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, seedPart []byte) *Voter {
	return &Voter{
		Input:    *new(fr.Element).Set(input),
		JoinEvt:  *joinEvt,
		PubKey:   *pubKey,
		HashSeed: hashSeed,
		SeedPart: seedPart,
	}
}



@@ 104,7 104,7 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore)
		CreateID: el.CreateEvt.ID,
		Input:    base64.StdEncoding.EncodeToString(inputBytes[:]),
		PubKey:   base64.StdEncoding.EncodeToString((*pubKey)[:]),
		HashSeed: base64.StdEncoding.EncodeToString(seedBytes[:]),
		SeedPart: base64.StdEncoding.EncodeToString(seedBytes[:]),
	})
	if err != nil {
		return fmt.Errorf("couldn't send join messages: %s", err)


@@ 171,7 171,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
		evalVerifyingKey  groth16.VerifyingKey
	)
	{
		r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs), len(el.Candidates))
		r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs), len(el.Candidates))
		if err != nil {
			return fmt.Errorf("couldn't compile eval circuit: %s", err)
		}


@@ 200,7 200,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
		sumVerifyingKey  groth16.VerifyingKey
	)
	{
		r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))
		r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))
		if err != nil {
			return fmt.Errorf("couldn't compile sum circuit: %s", err)
		}


@@ 346,7 346,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
		evalsIDs[i] = *voter.EvalsID
	}

	r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))
	r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))

	proofs := make([]string, numVoters)
	{


@@ 355,7 355,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
		outputs := make([]frontend.Variable, numVoters)

		var witness math.SumCircuit
		witness.HashSeeds = el.GetHashSeeds()
		witness.HashSeed = el.GetHashSeed()
		witness.HashSelects = hashSelects
		witness.OutputHashes = outputHashes
		witness.Sum.Assign(el.LocalVoter.Sum)

A math/hash.go => math/hash.go +11 -0
@@ 0,0 1,11 @@
package math

import "golang.org/x/crypto/sha3"

func HashSeedFromSeedParts(seedParts [][]byte) string {
	res := sha3.Sum256(seedParts[0])
	for _, part := range seedParts[1:] {
		res = sha3.Sum256(append(res[:], part...))
	}
	return string(res[:])
}

M math/lagrange_test.go => math/lagrange_test.go +6 -7
@@ 2,7 2,6 @@ package math

import (
	"crypto/rand"
	"encoding/base64"
	"io"
	"testing"



@@ 66,16 65,16 @@ func TestRandomPoly(t *testing.T) {
		}

		// hash seeds don't matter
		hashSeeds := make([]string, n)
		for i := range hashSeeds {
			var hashSeed [32]byte
			if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil {
		seedParts := make([][]byte, n)
		for i := range seedParts {
			var seedPart [32]byte
			if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil {
				t.Fatal(err)
			}
			hashSeeds[i] = base64.RawStdEncoding.EncodeToString(hashSeed[:])
			seedParts[i] = seedPart[:]
		}

		poly := NewRandomPoly(ballot, hashSeeds, inputs)
		poly := NewRandomPoly(ballot, HashSeedFromSeedParts(seedParts), inputs)
		points := make([]Point, n)
		for i := 0; i < n; i++ {
		genx:

M math/poly.go => math/poly.go +10 -10
@@ 10,10 10,10 @@ import (
)

type Poly struct {
	Ballot    [][]byte      `json:"ballot"`
	Coeffs    []*fr.Element `json:"coeffs"`
	HashSeeds []string      `json:"hash_seeds"`
	Inputs    []*fr.Element `json:"inputs"`
	Ballot   [][]byte      `json:"ballot"`
	Coeffs   []*fr.Element `json:"coeffs"`
	HashSeed string        `json:"hash_seed"`
	Inputs   []*fr.Element `json:"inputs"`

	cache *polyCache `json:"-"`
}


@@ 33,7 33,7 @@ type outputCache struct {

// NewRandomPoly generates a random polynomial of degree numVoters-1, using
// `ballot' as the constant term
func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *Poly {
func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly {
	coeffs := make([]*fr.Element, len(inputs) - 1)

	var err error


@@ 44,7 44,7 @@ func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *P
		}
	}

	return &Poly{ballot, coeffs, hashSeeds, inputs, nil}
	return &Poly{ballot, coeffs, hashSeed, inputs, nil}
}

func (p *Poly) setupCache() {


@@ 72,10 72,10 @@ func (p *Poly) setupCache() {

	// calculate outputs and their hashes
	p.cache.outputHashes = make(map[*fr.Element]outputCache)
	for i, input := range p.Inputs {
	for _, input := range p.Inputs {
		output := p.eval(input)
		byts := output.Bytes()
		hash, err := mimc.Sum(p.HashSeeds[i], byts[:])
		hash, err := mimc.Sum(p.HashSeed, byts[:])
		if err != nil {
			panic(err)
		}


@@ 85,7 85,7 @@ func (p *Poly) setupCache() {
	// compile R1CS
	{
		var err error
		p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeeds, len(p.Inputs), len(p.Ballot))
		p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeed, len(p.Inputs), len(p.Ballot))
		if err != nil {
			panic(err)
		}


@@ 95,7 95,7 @@ func (p *Poly) setupCache() {
	{
		witness := &p.cache.witness
		// public
		witness.HashSeeds = p.HashSeeds
		witness.HashSeed = p.HashSeed
		witness.Inputs = make([]frontend.Variable, len(p.Inputs))
		witness.OutputHashes = make([]frontend.Variable, len(p.Inputs))
		for i, input := range p.Inputs {

M math/zk.go => math/zk.go +13 -19
@@ 10,7 10,7 @@ import (
)

type EvalsCircuit struct {
	HashSeeds    []string
	HashSeed     string
	Inputs       []frontend.Variable `gnark:",public"`
	OutputHashes []frontend.Variable `gnark:",public"`



@@ 22,9 22,6 @@ type EvalsCircuit struct {
}

func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
	if len(circuit.HashSeeds) != len(circuit.Inputs) {
		return errors.New("len(circuit.HashSeeds) != len(circuit.Inputs)")
	}
	if len(circuit.Inputs) != len(circuit.OutputHashes) {
		return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)")
	}


@@ 47,7 44,7 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSyste
		}

		// hash
		mimc, err := mimc.NewMiMC(circuit.HashSeeds[i], ecc.BLS12_381)
		mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381)
		if err != nil {
			return err
		}


@@ 72,9 69,9 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSyste
	return nil
}

func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
	var circuit EvalsCircuit
	circuit.HashSeeds = hashSeeds
	circuit.HashSeed = hashSeed
	circuit.Inputs = make([]frontend.Variable, numVoters)
	circuit.OutputHashes = make([]frontend.Variable, numVoters)
	circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates)


@@ 87,7 84,7 @@ func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (fro
}

type SumCircuit struct {
	HashSeeds    []string
	HashSeed     string
	HashSelects  []frontend.Variable `gnark:",public"`
	OutputHashes []frontend.Variable `gnark:",public"`
	Sum          frontend.Variable   `gnark:",public"`


@@ 96,9 93,6 @@ type SumCircuit struct {
}

func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
	if len(circuit.HashSeeds) != len(circuit.HashSelects) {
		return errors.New("len(circuit.HashSeeds) != len(circuit.HashSelects)")
	}
	if len(circuit.HashSelects) != len(circuit.OutputHashes) {
		return errors.New("len(circuit.HashSelects) != len(circuit.OutputHashes)")
	}


@@ 108,13 102,13 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem)

	// prove the secret outputs hash to public hashes
	zero := cs.Constant(0)

	mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381)
	if err != nil {
		return err
	}
	// TODO: this seems verify inefficient
	for i, hashSeed := range circuit.HashSeeds {
		mimc, err := mimc.NewMiMC(hashSeed, ecc.BLS12_381)
		if err != nil {
			return err
		}
		bit := circuit.HashSelects[i]
	for _, bit := range circuit.HashSelects {
		for j, output := range circuit.Outputs {
			cs.AssertIsEqual(
				cs.Select(bit, mimc.Hash(cs, output), zero),


@@ 134,9 128,9 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem)
	return nil
}

func SumCircuitR1CS(hashSeeds []string, numVoters int) (frontend.CompiledConstraintSystem, error) {
func SumCircuitR1CS(hashSeed string, numVoters int) (frontend.CompiledConstraintSystem, error) {
	var circuit SumCircuit
	circuit.HashSeeds = hashSeeds
	circuit.HashSeed = hashSeed
	circuit.HashSelects = make([]frontend.Variable, numVoters)
	circuit.OutputHashes = make([]frontend.Variable, numVoters)
	circuit.Outputs = make([]frontend.Variable, numVoters)