From 6c109fa5b52b060ecae8ebf8958be7ec2f6799bb Mon Sep 17 00:00:00 2001 From: David Florness Date: Sat, 8 May 2021 21:05:07 -0400 Subject: [PATCH] Fix voter JSON marshalling --- election/marshal.go | 67 ++++++++++ election/marshal_test.go | 258 +++++++++++++++++++++++++++++++++++++++ election/msg.go | 18 +-- election/voter.go | 28 ++--- 4 files changed, 348 insertions(+), 23 deletions(-) create mode 100644 election/marshal.go create mode 100644 election/marshal_test.go diff --git a/election/marshal.go b/election/marshal.go new file mode 100644 index 0000000..08d3eb4 --- /dev/null +++ b/election/marshal.go @@ -0,0 +1,67 @@ +package election + +import ( + "bytes" + "encoding/base64" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" +) + +type MarshallableProvingKey struct { + Pk groth16.ProvingKey +} + +func (t *MarshallableProvingKey) UnmarshalJSON(b []byte) error { + var buf bytes.Buffer + _, err := buf.Write(bytes.Trim(b, "\"")) + if err != nil { + return err + } + t.Pk = groth16.NewProvingKey(ecc.BLS12_381) + _, err = t.Pk.ReadFrom(base64.NewDecoder(base64.StdEncoding, &buf)) + return err +} + +func (t *MarshallableProvingKey) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + _, err := t.Pk.WriteTo(encoder) + if err != nil { + return nil, err + } + err = encoder.Close() + if err != nil { + return nil, err + } + return append([]byte{'"'}, append(buf.Bytes(), '"')...), nil +} + +type MarshallableVerifyingKey struct { + Vk groth16.VerifyingKey +} + +func (t *MarshallableVerifyingKey) UnmarshalJSON(b []byte) error { + var buf bytes.Buffer + _, err := buf.Write(bytes.Trim(b, "\"")) + if err != nil { + return err + } + t.Vk = groth16.NewVerifyingKey(ecc.BLS12_381) + _, err = t.Vk.ReadFrom(base64.NewDecoder(base64.StdEncoding, &buf)) + return err +} + +func (t *MarshallableVerifyingKey) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + _, err := t.Vk.WriteRawTo(encoder) + if err != nil { + return nil, err + } + err = encoder.Close() + if err != nil { + return nil, err + } + return append([]byte{'"'}, append(buf.Bytes(), '"')...), nil +} diff --git a/election/marshal_test.go b/election/marshal_test.go new file mode 100644 index 0000000..a85d400 --- /dev/null +++ b/election/marshal_test.go @@ -0,0 +1,258 @@ +package election + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "encoding/json" + "io" + r "math/rand" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark/backend/groth16" + "golang.org/x/crypto/nacl/box" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "tallyard.xyz/math" +) + +func randomLocalVoter(t *testing.T) *LocalVoter { + // JoinIDIndex + const numVoters = 10 + joinIDIndex := uint(r.Intn(numVoters)) + + // Input(s) + inputs := make([]*fr.Element, numVoters) + for i := range inputs { + var err error + inputs[i], err = new(fr.Element).SetRandom() + if err != nil { + t.Fatal(err) + } + } + input := inputs[joinIDIndex] + inputBytes := input.Bytes() + + // PubKey / PrivKey + pubKey, privKey, err := box.GenerateKey(rand.Reader) + + // HashSeed(s) + hashSeeds := make([]string, numVoters) + for i := range hashSeeds { + var hashSeed [32]byte + if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil { + t.Fatal(err) + } + hashSeeds[i] = base64.StdEncoding.EncodeToString(hashSeed[:]) + } + hashSeed := hashSeeds[joinIDIndex] + + // JoinEvt + joinEvt := &event.Event{ + Sender: "@test:example.org", + Timestamp: time.Now().Unix(), + Type: JoinElectionMessage, + ID: "event1", + RoomID: "room1", + Content: event.Content{ + Parsed: JoinElectionContent{ + Version: Version, + CreateID: "createID", + Input: base64.StdEncoding.EncodeToString(inputBytes[:]), + PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]), + HashSeed: hashSeed, + }, + }, + } + + localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, hashSeed), privKey) + localVoter.JoinIDIndex = &joinIDIndex + + // Output + localVoter.Output, err = new(fr.Element).SetRandom() + if err != nil { + t.Fatal(err) + } + + // OutputHashes + outputHashes := make([][]byte, numVoters) + for i := range outputHashes { + var outputHash [32]byte + if _, err := io.ReadFull(rand.Reader, outputHash[:]); err != nil { + t.Fatal(err) + } + outputHashes[i] = outputHash[:] + } + localVoter.OutputHashes = &outputHashes + + // KeysID + keysID := id.EventID("keysID") + localVoter.KeysID = &keysID + + // EvalsID + evalsID := id.EventID("evalsID") + localVoter.EvalsID = &evalsID + + // Sum + localVoter.Sum, err = new(fr.Element).SetRandom() + if err != nil { + t.Fatal(err) + } + + // SumID + sumID := id.EventID("sumID") + localVoter.SumID = &sumID + + const numCandidates = 5 + + { + r1cs, err := math.EvalsCircuitR1CS(hashSeeds, numVoters, numCandidates) + if err != nil { + t.Fatal(err) + } + evalProvingKey, evalVerifyingKey, err := groth16.Setup(r1cs) + if err != nil { + t.Fatal(err) + } + + // EvalProvingKey + localVoter.EvalProvingKey = &MarshallableProvingKey{evalProvingKey} + // EvalVerifyingKey + localVoter.EvalVerifyingKey = &MarshallableVerifyingKey{evalVerifyingKey} + } + + { + r1cs, err := math.SumCircuitR1CS(hashSeeds, numVoters) + if err != nil { + t.Fatal(err) + } + sumProvingKey, sumVerifyingKey, err := groth16.Setup(r1cs) + if err != nil { + t.Fatal(err) + } + // SumProvingKey + localVoter.SumProvingKey = &MarshallableProvingKey{sumProvingKey} + // SumVerifyingKey + localVoter.SumVerifyingKey = &MarshallableVerifyingKey{sumVerifyingKey} + } + + ballot := make([][]byte, numCandidates) + for i := range ballot { + ballot[i] = make([]byte, numCandidates) + for j := range ballot[i] { + ballot[i][j] = byte(r.Intn(2)) + } + } + + // Poly + localVoter.Poly = math.NewRandomPoly(ballot, hashSeeds, inputs) + + return localVoter +} + +func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) { + // Input + if !lv1.Input.Equal(&lv2.Input) { + t.Errorf("inputs not equal: %s vs %s", &lv1.Input, &lv2.Input) + } + + // JoinEvt + if lv2.JoinEvt.Sender != lv1.JoinEvt.Sender { + t.Error("join event senders not equal") + } + if lv2.JoinEvt.Timestamp != lv1.JoinEvt.Timestamp { + t.Error("join event timestamps not equal") + } + // we don't yet parse the type right away + // if unmarshalledLocalVoter.JoinEvt.Type != localVoter.JoinEvt.Type { + // t.Error("join event types not equal") + // } + if lv2.JoinEvt.ID != lv1.JoinEvt.ID { + t.Error("join event IDs not equal") + } + if lv2.JoinEvt.RoomID != lv1.JoinEvt.RoomID { + t.Error("join event RoomIDs not equal") + } + // TODO: JoinEvt.Content + + // PubKey + if bytes.Compare(lv1.PubKey[:], lv2.PubKey[:]) != 0 { + t.Errorf("pubkeys not equal") + } + + // HashSeed + if lv1.HashSeed != lv2.HashSeed { + t.Errorf("hash seeds not equal") + } + + // JoinIDIndex + if *lv2.JoinIDIndex != *lv1.JoinIDIndex { + t.Error("join ID indices not equal") + } + + // Output + if !lv2.Output.Equal(lv1.Output) { + t.Error("outputs not equal") + } + + // OutputHashes + for i, outputHash := range *lv2.OutputHashes { + if !bytes.Equal(outputHash, (*lv1.OutputHashes)[i]) { + t.Error("output hashes not equal") + } + } + + // KeysID + if *lv2.KeysID != *lv1.KeysID { + t.Error("keys IDs not equal") + } + // EvalsID + if *lv2.EvalsID != *lv1.EvalsID { + t.Error("evals IDs not equal") + } + // Sum + if !lv2.Sum.Equal(lv1.Sum) { + t.Error("sums not equal") + } + // SumID + if *lv2.SumID != *lv1.SumID { + t.Error("sum IDs not equal") + } + + // EvalProvingKey + if lv2.EvalProvingKey.Pk.IsDifferent(lv1.EvalProvingKey.Pk) { + t.Error("eval proving keys not equal") + } + // EvalVerifyingKey + if lv2.EvalVerifyingKey.Vk.IsDifferent(lv1.EvalVerifyingKey.Vk) { + t.Error("eval verifying keys not equal") + } + // SumProvingKey + if lv2.SumProvingKey.Pk.IsDifferent(lv1.SumProvingKey.Pk) { + t.Error("sum proving keys not equal") + } + // SumVerifyingKey + if lv2.SumVerifyingKey.Vk.IsDifferent(lv1.SumVerifyingKey.Vk) { + t.Error("sum verifying keys not equal") + } +} + +func TestLocalVoterMarshal(t *testing.T) { + localVoter := randomLocalVoter(t) + + b, err := json.Marshal(localVoter) + if err != nil { + t.Fatal(err) + } + + unmarshalledLocalVoter := new(LocalVoter) + + err = json.Unmarshal(b, &unmarshalledLocalVoter) + if err != nil { + t.Fatal(err) + } + + ensureEqual(t, localVoter, unmarshalledLocalVoter) +} diff --git a/election/msg.go b/election/msg.go index d9f39e0..0491dbb 100644 --- a/election/msg.go +++ b/election/msg.go @@ -281,12 +281,12 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success warnf("the hash seed is empty") return } - hashSeed, err := base64.StdEncoding.DecodeString(content.HashSeed) + hashSeedBytes, err := base64.StdEncoding.DecodeString(content.HashSeed) if err != nil { warnf("we couldn't decode their hash seed: %s", err) return } - if len(hashSeed) < 32 { + if len(hashSeedBytes) < 32 { warnf("their hash seed is fewer than 32 bytes") return } @@ -325,7 +325,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success defer el.Save() defer el.Unlock() - el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, hashSeed) + el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, content.HashSeed) return true } @@ -389,7 +389,7 @@ func (elections *ElectionsMap) onStartElectionMessage(evt *event.Event) (success el.FinalJoinIDs = &content.JoinIDs for i, joinID := range *el.FinalJoinIDs { // TODO: gross - joinIndex := i + var joinIndex uint = uint(i) el.Joins[joinID].JoinIDIndex = &joinIndex } @@ -500,8 +500,8 @@ func (elections *ElectionsMap) onKeysMessage(evt *event.Event, client *mautrix.C defer el.Save() defer el.Unlock() - voter.EvalProvingKey = &evalProvingKey - voter.SumProvingKey = &sumProvingKey + voter.EvalProvingKey = &MarshallableProvingKey{evalProvingKey} + voter.SumProvingKey = &MarshallableProvingKey{sumProvingKey} voter.KeysID = &evt.ID return true @@ -670,7 +670,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { errorf("our evals verifying key is nil") return } - err := groth16.Verify(proof, *el.LocalVoter.EvalVerifyingKey, &publicCircuit) + err := groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk, &publicCircuit) if err != nil { warnf("poly eval proof verification failed: %s", err) return @@ -827,7 +827,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) { for i, joinID := range *el.FinalJoinIDs { evaler := el.Joins[joinID] hashSeeds[i] = evaler.HashSeed - if i == voterJoinIDIndex { + if uint(i) == voterJoinIDIndex { hashSelects[i].Assign(1) } else { hashSelects[i].Assign(0) @@ -846,7 +846,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) { errorf("our sum verifying key is nil") return } - err := groth16.Verify(proof, *el.LocalVoter.SumVerifyingKey, &publicCircuit) + err := groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk, &publicCircuit) if err != nil { warnf("sum proof verification failed: %s", err) return diff --git a/election/voter.go b/election/voter.go index f3eb508..f124ee3 100644 --- a/election/voter.go +++ b/election/voter.go @@ -24,7 +24,7 @@ type Voter struct { PubKey [32]byte `json:"pub_key"` HashSeed string `json:"hash_seed"` - JoinIDIndex *int `json:"join_id_index,omitempty"` + JoinIDIndex *uint `json:"join_id_index,omitempty"` Output *fr.Element `json:"output,omitempty"` OutputHashes *[][]byte `json:"output_hashes,omitempty"` KeysID *id.EventID `json:"keys_id,omitempty"` @@ -32,26 +32,26 @@ type Voter struct { Sum *fr.Element `json:"sum,omitempty"` SumID *id.EventID `json:"sum_id,omitempty"` - EvalProvingKey *groth16.ProvingKey `json:"eval_proving_key"` - SumProvingKey *groth16.ProvingKey `json:"sum_proving_key"` + EvalProvingKey *MarshallableProvingKey `json:"eval_proving_key,omitempty"` + SumProvingKey *MarshallableProvingKey `json:"sum_proving_key,omitempty"` } type LocalVoter struct { *Voter - PrivKey [32]byte `json:"priv_key"` + PrivKey [32]byte `json:"priv_key"` - EvalVerifyingKey *groth16.VerifyingKey `json:"eval_verifying_key"` - Poly *math.Poly `json:"poly,omitempty"` - SumVerifyingKey *groth16.VerifyingKey `json:"sum_verifying_key"` + EvalVerifyingKey *MarshallableVerifyingKey `json:"eval_verifying_key,omitempty"` + Poly *math.Poly `json:"poly,omitempty"` + SumVerifyingKey *MarshallableVerifyingKey `json:"sum_verifying_key,omitempty"` } -func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed []byte) *Voter { +func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed string) *Voter { return &Voter{ Input: *new(fr.Element).Set(input), JoinEvt: *joinEvt, PubKey: *pubKey, - HashSeed: string(hashSeed), + HashSeed: hashSeed, } } @@ -251,8 +251,8 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto defer el.Save() defer el.Unlock() - el.LocalVoter.EvalVerifyingKey = &evalVerifyingKey - el.LocalVoter.SumVerifyingKey = &sumVerifyingKey + el.LocalVoter.EvalVerifyingKey = &MarshallableVerifyingKey{evalVerifyingKey} + el.LocalVoter.SumVerifyingKey = &MarshallableVerifyingKey{sumVerifyingKey} return nil } @@ -269,7 +269,7 @@ func (el *Election) SendEvals(client *mautrix.Client, eventStore *EventStore) er for i, joinID := range *el.FinalJoinIDs { voter := el.Joins[joinID] - output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, *voter.EvalProvingKey) + output, outputHash, proof := el.LocalVoter.Poly.EvalAndProve(&voter.Input, voter.EvalProvingKey.Pk) // encrypt output var outputBytes [32]byte @@ -362,7 +362,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro ourJoinIDIndex := *el.LocalVoter.JoinIDIndex for i, joinID := range *el.FinalJoinIDs { - if i == ourJoinIDIndex { + if uint(i) == ourJoinIDIndex { hashSelects[i].Assign(1) } else { hashSelects[i].Assign(0) @@ -375,7 +375,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro for i, joinID := range *el.FinalJoinIDs { voter := *el.Joins[joinID] - proof, err := groth16.Prove(r1cs, *voter.SumProvingKey, &witness) + proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk, &witness) if err != nil { el.RUnlock() return fmt.Errorf("couldn't prove sum: %s", err) -- 2.38.4