~edwargix/tallyard

6c109fa5b52b060ecae8ebf8958be7ec2f6799bb — David Florness 2 years ago 1e20c7a
Fix voter JSON marshalling
4 files changed, 348 insertions(+), 23 deletions(-)

A election/marshal.go
A election/marshal_test.go
M election/msg.go
M election/voter.go
A election/marshal.go => election/marshal.go +67 -0
@@ 0,0 1,67 @@
package election

import (
	"bytes"
	"encoding/base64"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark/backend/groth16"
)

type MarshallableProvingKey struct {
	Pk groth16.ProvingKey
}

func (t *MarshallableProvingKey) UnmarshalJSON(b []byte) error {
	var buf bytes.Buffer
	_, err := buf.Write(bytes.Trim(b, "\""))
	if err != nil {
		return err
	}
	t.Pk = groth16.NewProvingKey(ecc.BLS12_381)
	_, err = t.Pk.ReadFrom(base64.NewDecoder(base64.StdEncoding, &buf))
	return err
}

func (t *MarshallableProvingKey) MarshalJSON() ([]byte, error) {
	var buf bytes.Buffer
	encoder := base64.NewEncoder(base64.StdEncoding, &buf)
	_, err := t.Pk.WriteTo(encoder)
	if err != nil {
		return nil, err
	}
	err = encoder.Close()
	if err != nil {
		return nil, err
	}
	return append([]byte{'"'}, append(buf.Bytes(), '"')...), nil
}

type MarshallableVerifyingKey struct {
	Vk groth16.VerifyingKey
}

func (t *MarshallableVerifyingKey) UnmarshalJSON(b []byte) error {
	var buf bytes.Buffer
	_, err := buf.Write(bytes.Trim(b, "\""))
	if err != nil {
		return err
	}
	t.Vk = groth16.NewVerifyingKey(ecc.BLS12_381)
	_, err = t.Vk.ReadFrom(base64.NewDecoder(base64.StdEncoding, &buf))
	return err
}

func (t *MarshallableVerifyingKey) MarshalJSON() ([]byte, error) {
	var buf bytes.Buffer
	encoder := base64.NewEncoder(base64.StdEncoding, &buf)
	_, err := t.Vk.WriteRawTo(encoder)
	if err != nil {
		return nil, err
	}
	err = encoder.Close()
	if err != nil {
		return nil, err
	}
	return append([]byte{'"'}, append(buf.Bytes(), '"')...), nil
}

A election/marshal_test.go => election/marshal_test.go +258 -0
@@ 0,0 1,258 @@
package election

import (
	"bytes"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"io"
	r "math/rand"
	"testing"
	"time"

	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
	"github.com/consensys/gnark/backend/groth16"
	"golang.org/x/crypto/nacl/box"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
	"tallyard.xyz/math"
)

func randomLocalVoter(t *testing.T) *LocalVoter {
	// JoinIDIndex
	const numVoters = 10
	joinIDIndex := uint(r.Intn(numVoters))

	// Input(s)
	inputs := make([]*fr.Element, numVoters)
	for i := range inputs {
		var err error
		inputs[i], err = new(fr.Element).SetRandom()
		if err != nil {
			t.Fatal(err)
		}
	}
	input := inputs[joinIDIndex]
	inputBytes := input.Bytes()

	// PubKey / PrivKey
	pubKey, privKey, err := box.GenerateKey(rand.Reader)

	// HashSeed(s)
	hashSeeds := make([]string, numVoters)
	for i := range hashSeeds {
		var hashSeed [32]byte
		if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil {
			t.Fatal(err)
		}
		hashSeeds[i] = base64.StdEncoding.EncodeToString(hashSeed[:])
	}
	hashSeed := hashSeeds[joinIDIndex]

	// JoinEvt
	joinEvt := &event.Event{
		Sender:    "@test:example.org",
		Timestamp: time.Now().Unix(),
		Type:      JoinElectionMessage,
		ID:        "event1",
		RoomID:    "room1",
		Content:   event.Content{
			Parsed: JoinElectionContent{
				Version:  Version,
				CreateID: "createID",
				Input:    base64.StdEncoding.EncodeToString(inputBytes[:]),
				PubKey:   base64.StdEncoding.EncodeToString((*pubKey)[:]),
				HashSeed: hashSeed,
			},
		},
	}

	localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, hashSeed), privKey)
	localVoter.JoinIDIndex = &joinIDIndex

	// Output
	localVoter.Output, err = new(fr.Element).SetRandom()
	if err != nil {
		t.Fatal(err)
	}

	// OutputHashes
	outputHashes := make([][]byte, numVoters)
	for i := range outputHashes {
		var outputHash [32]byte
		if _, err := io.ReadFull(rand.Reader, outputHash[:]); err != nil {
			t.Fatal(err)
		}
		outputHashes[i] = outputHash[:]
	}
	localVoter.OutputHashes = &outputHashes

	// KeysID
	keysID := id.EventID("keysID")
	localVoter.KeysID = &keysID

	// EvalsID
	evalsID := id.EventID("evalsID")
	localVoter.EvalsID = &evalsID

	// Sum
	localVoter.Sum, err = new(fr.Element).SetRandom()
	if err != nil {
		t.Fatal(err)
	}

	// SumID
	sumID := id.EventID("sumID")
	localVoter.SumID = &sumID

	const numCandidates = 5

	{
		r1cs, err := math.EvalsCircuitR1CS(hashSeeds, numVoters, numCandidates)
		if err != nil {
			t.Fatal(err)
		}
		evalProvingKey, evalVerifyingKey, err := groth16.Setup(r1cs)
		if err != nil {
			t.Fatal(err)
		}

		// EvalProvingKey
		localVoter.EvalProvingKey   = &MarshallableProvingKey{evalProvingKey}
		// EvalVerifyingKey
		localVoter.EvalVerifyingKey = &MarshallableVerifyingKey{evalVerifyingKey}
	}

	{
		r1cs, err := math.SumCircuitR1CS(hashSeeds, numVoters)
		if err != nil {
			t.Fatal(err)
		}
		sumProvingKey, sumVerifyingKey, err := groth16.Setup(r1cs)
		if err != nil {
			t.Fatal(err)
		}
		// SumProvingKey
		localVoter.SumProvingKey    = &MarshallableProvingKey{sumProvingKey}
		// SumVerifyingKey
		localVoter.SumVerifyingKey  = &MarshallableVerifyingKey{sumVerifyingKey}
	}

	ballot := make([][]byte, numCandidates)
	for i := range ballot {
		ballot[i] = make([]byte, numCandidates)
		for j := range ballot[i] {
			ballot[i][j] = byte(r.Intn(2))
		}
	}

	// Poly
	localVoter.Poly = math.NewRandomPoly(ballot, hashSeeds, inputs)

	return localVoter
}

func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) {
	// Input
	if !lv1.Input.Equal(&lv2.Input) {
		t.Errorf("inputs not equal: %s vs %s", &lv1.Input, &lv2.Input)
	}

	// JoinEvt
	if lv2.JoinEvt.Sender != lv1.JoinEvt.Sender {
		t.Error("join event senders not equal")
	}
	if lv2.JoinEvt.Timestamp != lv1.JoinEvt.Timestamp {
		t.Error("join event timestamps not equal")
	}
	// we don't yet parse the type right away
	// if unmarshalledLocalVoter.JoinEvt.Type != localVoter.JoinEvt.Type {
	// 	t.Error("join event types not equal")
	// }
	if lv2.JoinEvt.ID != lv1.JoinEvt.ID {
		t.Error("join event IDs not equal")
	}
	if lv2.JoinEvt.RoomID != lv1.JoinEvt.RoomID {
		t.Error("join event RoomIDs not equal")
	}
	// TODO: JoinEvt.Content

	// PubKey
	if bytes.Compare(lv1.PubKey[:], lv2.PubKey[:]) != 0 {
		t.Errorf("pubkeys not equal")
	}

	// HashSeed
	if lv1.HashSeed != lv2.HashSeed {
		t.Errorf("hash seeds not equal")
	}

	// JoinIDIndex
	if *lv2.JoinIDIndex != *lv1.JoinIDIndex {
		t.Error("join ID indices not equal")
	}

	// Output
	if !lv2.Output.Equal(lv1.Output) {
		t.Error("outputs not equal")
	}

	// OutputHashes
	for i, outputHash := range *lv2.OutputHashes {
		if !bytes.Equal(outputHash, (*lv1.OutputHashes)[i]) {
			t.Error("output hashes not equal")
		}
	}

	// KeysID
	if *lv2.KeysID != *lv1.KeysID {
		t.Error("keys IDs not equal")
	}
	// EvalsID
	if *lv2.EvalsID != *lv1.EvalsID {
		t.Error("evals IDs not equal")
	}
	// Sum
	if !lv2.Sum.Equal(lv1.Sum) {
		t.Error("sums not equal")
	}
	// SumID
	if *lv2.SumID != *lv1.SumID {
		t.Error("sum IDs not equal")
	}

	// EvalProvingKey
	if lv2.EvalProvingKey.Pk.IsDifferent(lv1.EvalProvingKey.Pk) {
		t.Error("eval proving keys not equal")
	}
	// EvalVerifyingKey
	if lv2.EvalVerifyingKey.Vk.IsDifferent(lv1.EvalVerifyingKey.Vk) {
		t.Error("eval verifying keys not equal")
	}
	// SumProvingKey
	if lv2.SumProvingKey.Pk.IsDifferent(lv1.SumProvingKey.Pk) {
		t.Error("sum proving keys not equal")
	}
	// SumVerifyingKey
	if lv2.SumVerifyingKey.Vk.IsDifferent(lv1.SumVerifyingKey.Vk) {
		t.Error("sum verifying keys not equal")
	}
}

func TestLocalVoterMarshal(t *testing.T) {
	localVoter := randomLocalVoter(t)

	b, err := json.Marshal(localVoter)
	if err != nil {
		t.Fatal(err)
	}

	unmarshalledLocalVoter := new(LocalVoter)

	err = json.Unmarshal(b, &unmarshalledLocalVoter)
	if err != nil {
		t.Fatal(err)
	}

	ensureEqual(t, localVoter, unmarshalledLocalVoter)
}

M election/msg.go => election/msg.go +9 -9
@@ 281,12 281,12 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success 
		warnf("the hash seed is empty")
		return
	}
	hashSeed, err := base64.StdEncoding.DecodeString(content.HashSeed)
	hashSeedBytes, err := base64.StdEncoding.DecodeString(content.HashSeed)
	if err != nil {
		warnf("we couldn't decode their hash seed: %s", err)
		return
	}
	if len(hashSeed) < 32 {
	if len(hashSeedBytes) < 32 {
		warnf("their hash seed is fewer than 32 bytes")
		return
	}


@@ 325,7 325,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success 
	defer el.Save()
	defer el.Unlock()

	el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, hashSeed)
	el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, content.HashSeed)

	return true
}


@@ 389,7 389,7 @@ func (elections *ElectionsMap) onStartElectionMessage(evt *event.Event) (success
	el.FinalJoinIDs = &content.JoinIDs
	for i, joinID := range *el.FinalJoinIDs {
		// TODO: gross
		joinIndex := i
		var joinIndex uint = uint(i)
		el.Joins[joinID].JoinIDIndex = &joinIndex
	}



@@ 500,8 500,8 @@ func (elections *ElectionsMap) onKeysMessage(evt *event.Event, client *mautrix.C
	defer el.Save()
	defer el.Unlock()

	voter.EvalProvingKey = &evalProvingKey
	voter.SumProvingKey = &sumProvingKey
	voter.EvalProvingKey = &MarshallableProvingKey{evalProvingKey}
	voter.SumProvingKey = &MarshallableProvingKey{sumProvingKey}
	voter.KeysID = &evt.ID

	return true


@@ 670,7 670,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
			errorf("our evals verifying key is nil")
			return
		}
		err := groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit)
		err := groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk, &publicCircuit)
		if err != nil {
			warnf("poly eval proof verification failed: %s", err)
			return


@@ 827,7 827,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
		for i, joinID := range *el.FinalJoinIDs {
			evaler := el.Joins[joinID]
			hashSeeds[i] = evaler.HashSeed
			if i == voterJoinIDIndex {
			if uint(i) == voterJoinIDIndex {
				hashSelects[i].Assign(1)
			} else {
				hashSelects[i].Assign(0)


@@ 846,7 846,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
			errorf("our sum verifying key is nil")
			return
		}
		err := groth16.Verify(proof, *el.LocalVoter.SumVerifyingKey, &publicCircuit)
		err := groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk, &publicCircuit)
		if err != nil {
			warnf("sum proof verification failed: %s", err)
			return

M election/voter.go => election/voter.go +14 -14
@@ 24,7 24,7 @@ type Voter struct {
	PubKey   [32]byte    `json:"pub_key"`
	HashSeed string      `json:"hash_seed"`

	JoinIDIndex  *int        `json:"join_id_index,omitempty"`
	JoinIDIndex  *uint       `json:"join_id_index,omitempty"`
	Output       *fr.Element `json:"output,omitempty"`
	OutputHashes *[][]byte   `json:"output_hashes,omitempty"`
	KeysID       *id.EventID `json:"keys_id,omitempty"`


@@ 32,26 32,26 @@ type Voter struct {
	Sum          *fr.Element `json:"sum,omitempty"`
	SumID        *id.EventID `json:"sum_id,omitempty"`

	EvalProvingKey *groth16.ProvingKey `json:"eval_proving_key"`
	SumProvingKey  *groth16.ProvingKey `json:"sum_proving_key"`
	EvalProvingKey *MarshallableProvingKey `json:"eval_proving_key,omitempty"`
	SumProvingKey  *MarshallableProvingKey `json:"sum_proving_key,omitempty"`
}

type LocalVoter struct {
	*Voter

	PrivKey          [32]byte             `json:"priv_key"`
	PrivKey [32]byte `json:"priv_key"`

	EvalVerifyingKey *groth16.VerifyingKey `json:"eval_verifying_key"`
	Poly             *math.Poly            `json:"poly,omitempty"`
	SumVerifyingKey  *groth16.VerifyingKey `json:"sum_verifying_key"`
	EvalVerifyingKey *MarshallableVerifyingKey `json:"eval_verifying_key,omitempty"`
	Poly             *math.Poly                `json:"poly,omitempty"`
	SumVerifyingKey  *MarshallableVerifyingKey `json:"sum_verifying_key,omitempty"`
}

func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed []byte) *Voter {
func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed string) *Voter {
	return &Voter{
		Input:    *new(fr.Element).Set(input),
		JoinEvt:  *joinEvt,
		PubKey:   *pubKey,
		HashSeed: string(hashSeed),
		HashSeed: hashSeed,
	}
}



@@ 251,8 251,8 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
	defer el.Save()
	defer el.Unlock()

	el.LocalVoter.EvalVerifyingKey = &evalVerifyingKey
	el.LocalVoter.SumVerifyingKey = &sumVerifyingKey
	el.LocalVoter.EvalVerifyingKey = &MarshallableVerifyingKey{evalVerifyingKey}
	el.LocalVoter.SumVerifyingKey = &MarshallableVerifyingKey{sumVerifyingKey}

	return nil
}


@@ 269,7 269,7 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
	for i, joinID := range *el.FinalJoinIDs {
		voter := el.Joins[joinID]

		output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey)
		output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, voter.EvalProvingKey.Pk)

		// encrypt output
		var outputBytes [32]byte


@@ 362,7 362,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro

		ourJoinIDIndex := *el.LocalVoter.JoinIDIndex
		for i, joinID := range *el.FinalJoinIDs {
			if i == ourJoinIDIndex {
			if uint(i) == ourJoinIDIndex {
				hashSelects[i].Assign(1)
			} else {
				hashSelects[i].Assign(0)


@@ 375,7 375,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro

		for i, joinID := range *el.FinalJoinIDs {
			voter := *el.Joins[joinID]
			proof, err := groth16.Prove(r1cs, *voter.SumProvingKey, &witness)
			proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk, &witness)
			if err != nil {
				el.RUnlock()
				return fmt.Errorf("couldn't prove sum: %s", err)