A election/marshal.go => election/marshal.go +67 -0
@@ 0,0 1,67 @@
+package election
+
+import (
+ "bytes"
+ "encoding/base64"
+
+ "github.com/consensys/gnark-crypto/ecc"
+ "github.com/consensys/gnark/backend/groth16"
+)
+
+type MarshallableProvingKey struct {
+ Pk groth16.ProvingKey
+}
+
+func (t *MarshallableProvingKey) UnmarshalJSON(b []byte) error {
+ var buf bytes.Buffer
+ _, err := buf.Write(bytes.Trim(b, "\""))
+ if err != nil {
+ return err
+ }
+ t.Pk = groth16.NewProvingKey(ecc.BLS12_381)
+ _, err = t.Pk.ReadFrom(base64.NewDecoder(base64.StdEncoding, &buf))
+ return err
+}
+
+func (t *MarshallableProvingKey) MarshalJSON() ([]byte, error) {
+ var buf bytes.Buffer
+ encoder := base64.NewEncoder(base64.StdEncoding, &buf)
+ _, err := t.Pk.WriteTo(encoder)
+ if err != nil {
+ return nil, err
+ }
+ err = encoder.Close()
+ if err != nil {
+ return nil, err
+ }
+ return append([]byte{'"'}, append(buf.Bytes(), '"')...), nil
+}
+
+type MarshallableVerifyingKey struct {
+ Vk groth16.VerifyingKey
+}
+
+func (t *MarshallableVerifyingKey) UnmarshalJSON(b []byte) error {
+ var buf bytes.Buffer
+ _, err := buf.Write(bytes.Trim(b, "\""))
+ if err != nil {
+ return err
+ }
+ t.Vk = groth16.NewVerifyingKey(ecc.BLS12_381)
+ _, err = t.Vk.ReadFrom(base64.NewDecoder(base64.StdEncoding, &buf))
+ return err
+}
+
+func (t *MarshallableVerifyingKey) MarshalJSON() ([]byte, error) {
+ var buf bytes.Buffer
+ encoder := base64.NewEncoder(base64.StdEncoding, &buf)
+ _, err := t.Vk.WriteRawTo(encoder)
+ if err != nil {
+ return nil, err
+ }
+ err = encoder.Close()
+ if err != nil {
+ return nil, err
+ }
+ return append([]byte{'"'}, append(buf.Bytes(), '"')...), nil
+}
A election/marshal_test.go => election/marshal_test.go +258 -0
@@ 0,0 1,258 @@
+package election
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "io"
+ r "math/rand"
+ "testing"
+ "time"
+
+ "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
+ "github.com/consensys/gnark/backend/groth16"
+ "golang.org/x/crypto/nacl/box"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+ "tallyard.xyz/math"
+)
+
+func randomLocalVoter(t *testing.T) *LocalVoter {
+ // JoinIDIndex
+ const numVoters = 10
+ joinIDIndex := uint(r.Intn(numVoters))
+
+ // Input(s)
+ inputs := make([]*fr.Element, numVoters)
+ for i := range inputs {
+ var err error
+ inputs[i], err = new(fr.Element).SetRandom()
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ input := inputs[joinIDIndex]
+ inputBytes := input.Bytes()
+
+ // PubKey / PrivKey
+ pubKey, privKey, err := box.GenerateKey(rand.Reader)
+
+ // HashSeed(s)
+ hashSeeds := make([]string, numVoters)
+ for i := range hashSeeds {
+ var hashSeed [32]byte
+ if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil {
+ t.Fatal(err)
+ }
+ hashSeeds[i] = base64.StdEncoding.EncodeToString(hashSeed[:])
+ }
+ hashSeed := hashSeeds[joinIDIndex]
+
+ // JoinEvt
+ joinEvt := &event.Event{
+ Sender: "@test:example.org",
+ Timestamp: time.Now().Unix(),
+ Type: JoinElectionMessage,
+ ID: "event1",
+ RoomID: "room1",
+ Content: event.Content{
+ Parsed: JoinElectionContent{
+ Version: Version,
+ CreateID: "createID",
+ Input: base64.StdEncoding.EncodeToString(inputBytes[:]),
+ PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]),
+ HashSeed: hashSeed,
+ },
+ },
+ }
+
+ localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, hashSeed), privKey)
+ localVoter.JoinIDIndex = &joinIDIndex
+
+ // Output
+ localVoter.Output, err = new(fr.Element).SetRandom()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // OutputHashes
+ outputHashes := make([][]byte, numVoters)
+ for i := range outputHashes {
+ var outputHash [32]byte
+ if _, err := io.ReadFull(rand.Reader, outputHash[:]); err != nil {
+ t.Fatal(err)
+ }
+ outputHashes[i] = outputHash[:]
+ }
+ localVoter.OutputHashes = &outputHashes
+
+ // KeysID
+ keysID := id.EventID("keysID")
+ localVoter.KeysID = &keysID
+
+ // EvalsID
+ evalsID := id.EventID("evalsID")
+ localVoter.EvalsID = &evalsID
+
+ // Sum
+ localVoter.Sum, err = new(fr.Element).SetRandom()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // SumID
+ sumID := id.EventID("sumID")
+ localVoter.SumID = &sumID
+
+ const numCandidates = 5
+
+ {
+ r1cs, err := math.EvalsCircuitR1CS(hashSeeds, numVoters, numCandidates)
+ if err != nil {
+ t.Fatal(err)
+ }
+ evalProvingKey, evalVerifyingKey, err := groth16.Setup(r1cs)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // EvalProvingKey
+ localVoter.EvalProvingKey = &MarshallableProvingKey{evalProvingKey}
+ // EvalVerifyingKey
+ localVoter.EvalVerifyingKey = &MarshallableVerifyingKey{evalVerifyingKey}
+ }
+
+ {
+ r1cs, err := math.SumCircuitR1CS(hashSeeds, numVoters)
+ if err != nil {
+ t.Fatal(err)
+ }
+ sumProvingKey, sumVerifyingKey, err := groth16.Setup(r1cs)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // SumProvingKey
+ localVoter.SumProvingKey = &MarshallableProvingKey{sumProvingKey}
+ // SumVerifyingKey
+ localVoter.SumVerifyingKey = &MarshallableVerifyingKey{sumVerifyingKey}
+ }
+
+ ballot := make([][]byte, numCandidates)
+ for i := range ballot {
+ ballot[i] = make([]byte, numCandidates)
+ for j := range ballot[i] {
+ ballot[i][j] = byte(r.Intn(2))
+ }
+ }
+
+ // Poly
+ localVoter.Poly = math.NewRandomPoly(ballot, hashSeeds, inputs)
+
+ return localVoter
+}
+
+func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) {
+ // Input
+ if !lv1.Input.Equal(&lv2.Input) {
+ t.Errorf("inputs not equal: %s vs %s", &lv1.Input, &lv2.Input)
+ }
+
+ // JoinEvt
+ if lv2.JoinEvt.Sender != lv1.JoinEvt.Sender {
+ t.Error("join event senders not equal")
+ }
+ if lv2.JoinEvt.Timestamp != lv1.JoinEvt.Timestamp {
+ t.Error("join event timestamps not equal")
+ }
+ // we don't yet parse the type right away
+ // if unmarshalledLocalVoter.JoinEvt.Type != localVoter.JoinEvt.Type {
+ // t.Error("join event types not equal")
+ // }
+ if lv2.JoinEvt.ID != lv1.JoinEvt.ID {
+ t.Error("join event IDs not equal")
+ }
+ if lv2.JoinEvt.RoomID != lv1.JoinEvt.RoomID {
+ t.Error("join event RoomIDs not equal")
+ }
+ // TODO: JoinEvt.Content
+
+ // PubKey
+ if bytes.Compare(lv1.PubKey[:], lv2.PubKey[:]) != 0 {
+ t.Errorf("pubkeys not equal")
+ }
+
+ // HashSeed
+ if lv1.HashSeed != lv2.HashSeed {
+ t.Errorf("hash seeds not equal")
+ }
+
+ // JoinIDIndex
+ if *lv2.JoinIDIndex != *lv1.JoinIDIndex {
+ t.Error("join ID indices not equal")
+ }
+
+ // Output
+ if !lv2.Output.Equal(lv1.Output) {
+ t.Error("outputs not equal")
+ }
+
+ // OutputHashes
+ for i, outputHash := range *lv2.OutputHashes {
+ if !bytes.Equal(outputHash, (*lv1.OutputHashes)[i]) {
+ t.Error("output hashes not equal")
+ }
+ }
+
+ // KeysID
+ if *lv2.KeysID != *lv1.KeysID {
+ t.Error("keys IDs not equal")
+ }
+ // EvalsID
+ if *lv2.EvalsID != *lv1.EvalsID {
+ t.Error("evals IDs not equal")
+ }
+ // Sum
+ if !lv2.Sum.Equal(lv1.Sum) {
+ t.Error("sums not equal")
+ }
+ // SumID
+ if *lv2.SumID != *lv1.SumID {
+ t.Error("sum IDs not equal")
+ }
+
+ // EvalProvingKey
+ if lv2.EvalProvingKey.Pk.IsDifferent(lv1.EvalProvingKey.Pk) {
+ t.Error("eval proving keys not equal")
+ }
+ // EvalVerifyingKey
+ if lv2.EvalVerifyingKey.Vk.IsDifferent(lv1.EvalVerifyingKey.Vk) {
+ t.Error("eval verifying keys not equal")
+ }
+ // SumProvingKey
+ if lv2.SumProvingKey.Pk.IsDifferent(lv1.SumProvingKey.Pk) {
+ t.Error("sum proving keys not equal")
+ }
+ // SumVerifyingKey
+ if lv2.SumVerifyingKey.Vk.IsDifferent(lv1.SumVerifyingKey.Vk) {
+ t.Error("sum verifying keys not equal")
+ }
+}
+
+func TestLocalVoterMarshal(t *testing.T) {
+ localVoter := randomLocalVoter(t)
+
+ b, err := json.Marshal(localVoter)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ unmarshalledLocalVoter := new(LocalVoter)
+
+ err = json.Unmarshal(b, &unmarshalledLocalVoter)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ensureEqual(t, localVoter, unmarshalledLocalVoter)
+}
M election/msg.go => election/msg.go +9 -9
@@ 281,12 281,12 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success
warnf("the hash seed is empty")
return
}
- hashSeed, err := base64.StdEncoding.DecodeString(content.HashSeed)
+ hashSeedBytes, err := base64.StdEncoding.DecodeString(content.HashSeed)
if err != nil {
warnf("we couldn't decode their hash seed: %s", err)
return
}
- if len(hashSeed) < 32 {
+ if len(hashSeedBytes) < 32 {
warnf("their hash seed is fewer than 32 bytes")
return
}
@@ 325,7 325,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success
defer el.Save()
defer el.Unlock()
- el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, hashSeed)
+ el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, content.HashSeed)
return true
}
@@ 389,7 389,7 @@ func (elections *ElectionsMap) onStartElectionMessage(evt *event.Event) (success
el.FinalJoinIDs = &content.JoinIDs
for i, joinID := range *el.FinalJoinIDs {
// TODO: gross
- joinIndex := i
+ var joinIndex uint = uint(i)
el.Joins[joinID].JoinIDIndex = &joinIndex
}
@@ 500,8 500,8 @@ func (elections *ElectionsMap) onKeysMessage(evt *event.Event, client *mautrix.C
defer el.Save()
defer el.Unlock()
- voter.EvalProvingKey = &evalProvingKey
- voter.SumProvingKey = &sumProvingKey
+ voter.EvalProvingKey = &MarshallableProvingKey{evalProvingKey}
+ voter.SumProvingKey = &MarshallableProvingKey{sumProvingKey}
voter.KeysID = &evt.ID
return true
@@ 670,7 670,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
errorf("our evals verifying key is nil")
return
}
- err := groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit)
+ err := groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk, &publicCircuit)
if err != nil {
warnf("poly eval proof verification failed: %s", err)
return
@@ 827,7 827,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
for i, joinID := range *el.FinalJoinIDs {
evaler := el.Joins[joinID]
hashSeeds[i] = evaler.HashSeed
- if i == voterJoinIDIndex {
+ if uint(i) == voterJoinIDIndex {
hashSelects[i].Assign(1)
} else {
hashSelects[i].Assign(0)
@@ 846,7 846,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
errorf("our sum verifying key is nil")
return
}
- err := groth16.Verify(proof, *el.LocalVoter.SumVerifyingKey, &publicCircuit)
+ err := groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk, &publicCircuit)
if err != nil {
warnf("sum proof verification failed: %s", err)
return
M election/voter.go => election/voter.go +14 -14
@@ 24,7 24,7 @@ type Voter struct {
PubKey [32]byte `json:"pub_key"`
HashSeed string `json:"hash_seed"`
- JoinIDIndex *int `json:"join_id_index,omitempty"`
+ JoinIDIndex *uint `json:"join_id_index,omitempty"`
Output *fr.Element `json:"output,omitempty"`
OutputHashes *[][]byte `json:"output_hashes,omitempty"`
KeysID *id.EventID `json:"keys_id,omitempty"`
@@ 32,26 32,26 @@ type Voter struct {
Sum *fr.Element `json:"sum,omitempty"`
SumID *id.EventID `json:"sum_id,omitempty"`
- EvalProvingKey *groth16.ProvingKey `json:"eval_proving_key"`
- SumProvingKey *groth16.ProvingKey `json:"sum_proving_key"`
+ EvalProvingKey *MarshallableProvingKey `json:"eval_proving_key,omitempty"`
+ SumProvingKey *MarshallableProvingKey `json:"sum_proving_key,omitempty"`
}
type LocalVoter struct {
*Voter
- PrivKey [32]byte `json:"priv_key"`
+ PrivKey [32]byte `json:"priv_key"`
- EvalVerifyingKey *groth16.VerifyingKey `json:"eval_verifying_key"`
- Poly *math.Poly `json:"poly,omitempty"`
- SumVerifyingKey *groth16.VerifyingKey `json:"sum_verifying_key"`
+ EvalVerifyingKey *MarshallableVerifyingKey `json:"eval_verifying_key,omitempty"`
+ Poly *math.Poly `json:"poly,omitempty"`
+ SumVerifyingKey *MarshallableVerifyingKey `json:"sum_verifying_key,omitempty"`
}
-func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed []byte) *Voter {
+func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed string) *Voter {
return &Voter{
Input: *new(fr.Element).Set(input),
JoinEvt: *joinEvt,
PubKey: *pubKey,
- HashSeed: string(hashSeed),
+ HashSeed: hashSeed,
}
}
@@ 251,8 251,8 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
defer el.Save()
defer el.Unlock()
- el.LocalVoter.EvalVerifyingKey = &evalVerifyingKey
- el.LocalVoter.SumVerifyingKey = &sumVerifyingKey
+ el.LocalVoter.EvalVerifyingKey = &MarshallableVerifyingKey{evalVerifyingKey}
+ el.LocalVoter.SumVerifyingKey = &MarshallableVerifyingKey{sumVerifyingKey}
return nil
}
@@ 269,7 269,7 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er
for i, joinID := range *el.FinalJoinIDs {
voter := el.Joins[joinID]
- output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey)
+ output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, voter.EvalProvingKey.Pk)
// encrypt output
var outputBytes [32]byte
@@ 362,7 362,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
ourJoinIDIndex := *el.LocalVoter.JoinIDIndex
for i, joinID := range *el.FinalJoinIDs {
- if i == ourJoinIDIndex {
+ if uint(i) == ourJoinIDIndex {
hashSelects[i].Assign(1)
} else {
hashSelects[i].Assign(0)
@@ 375,7 375,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
for i, joinID := range *el.FinalJoinIDs {
voter := *el.Joins[joinID]
- proof, err := groth16.Prove(r1cs, *voter.SumProvingKey, &witness)
+ proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk, &witness)
if err != nil {
el.RUnlock()
return fmt.Errorf("couldn't prove sum: %s", err)