M cmd/tallyard/main.go => cmd/tallyard/main.go +1 -1
@@ 128,7 128,7 @@ func main() {
for i, joinID := range *el.FinalJoinIDs {
inputs[i] = &el.Joins[joinID].Input
}
- el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeeds(), inputs)
+ el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeed(), inputs)
el.Save()
}
M election/election.go => election/election.go +10 -8
@@ 7,6 7,7 @@ import (
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
+ "tallyard.xyz/math"
)
type Election struct {
@@ 25,7 26,7 @@ type Election struct {
// Save election to disk. Set by the containing ElectionsMap at runtime
Save func() `json:"-"`
- hashSeeds *[]string `json:"-"`
+ hashSeed *string `json:"-"`
}
func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election {
@@ 56,19 57,20 @@ func (el *Election) UnmarshalJSON(b []byte) error {
return nil
}
-func (el *Election) GetHashSeeds() []string {
- if el.hashSeeds != nil {
- return *el.hashSeeds
+func (el *Election) GetHashSeed() string {
+ if el.hashSeed != nil {
+ return *el.hashSeed
}
if el.FinalJoinIDs == nil {
panic("GatHashSeeds called before election started")
}
- hashSeeds := make([]string, len(*el.FinalJoinIDs))
+ seedParts := make([][]byte, len(*el.FinalJoinIDs))
for i, joinID := range *el.FinalJoinIDs {
- hashSeeds[i] = el.Joins[joinID].HashSeed
+ seedParts[i] = el.Joins[joinID].SeedPart
}
- el.hashSeeds = &hashSeeds
- return hashSeeds
+ hashSeed := math.HashSeedFromSeedParts(seedParts)
+ el.hashSeed = &hashSeed
+ return hashSeed
}
M election/marshal_test.go => election/marshal_test.go +15 -12
@@ 40,15 40,16 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
pubKey, privKey, err := box.GenerateKey(rand.Reader)
// HashSeed(s)
- hashSeeds := make([]string, numVoters)
- for i := range hashSeeds {
- var hashSeed [32]byte
- if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil {
+ seedParts := make([][]byte, numVoters)
+ for i := range seedParts {
+ var seedPart [32]byte
+ if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil {
t.Fatal(err)
}
- hashSeeds[i] = base64.StdEncoding.EncodeToString(hashSeed[:])
+ seedParts[i] = seedPart[:]
}
- hashSeed := hashSeeds[joinIDIndex]
+
+ localHashSeed := seedParts[joinIDIndex]
// JoinEvt
joinEvt := &event.Event{
@@ 63,14 64,16 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
CreateID: "createID",
Input: base64.StdEncoding.EncodeToString(inputBytes[:]),
PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]),
- HashSeed: hashSeed,
+ SeedPart: base64.StdEncoding.EncodeToString(localHashSeed),
},
},
}
- localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, hashSeed), privKey)
+ localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, localHashSeed), privKey)
localVoter.JoinIDIndex = &joinIDIndex
+ hashSeed := math.HashSeedFromSeedParts(seedParts)
+
// Output
localVoter.Output, err = new(fr.Element).SetRandom()
if err != nil {
@@ 109,7 112,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
const numCandidates = 5
{
- r1cs, err := math.EvalsCircuitR1CS(hashSeeds, numVoters, numCandidates)
+ r1cs, err := math.EvalsCircuitR1CS(hashSeed, numVoters, numCandidates)
if err != nil {
t.Fatal(err)
}
@@ 125,7 128,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
}
{
- r1cs, err := math.SumCircuitR1CS(hashSeeds, numVoters)
+ r1cs, err := math.SumCircuitR1CS(hashSeed, numVoters)
if err != nil {
t.Fatal(err)
}
@@ 148,7 151,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
}
// Poly
- localVoter.Poly = math.NewRandomPoly(ballot, hashSeeds, inputs)
+ localVoter.Poly = math.NewRandomPoly(ballot, hashSeed, inputs)
return localVoter
}
@@ 184,7 187,7 @@ func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) {
}
// HashSeed
- if lv1.HashSeed != lv2.HashSeed {
+ if bytes.Compare(lv1.SeedPart, lv2.SeedPart) != 0 {
t.Errorf("hash seeds not equal")
}
M election/msg.go => election/msg.go +12 -12
@@ 69,7 69,7 @@ type JoinElectionContent struct {
CreateID id.EventID `json:"create_id"`
Input string `json:"input"`
PubKey string `json:"pub_key"`
- HashSeed string `json:"hash_seed"`
+ SeedPart string `json:"seed_part"`
}
type StartElectionContent struct {
@@ 281,18 281,18 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success
return
}
- // HashSeed
- if content.HashSeed == "" {
- warnf("the hash seed is empty")
+ // SeedPart
+ if content.SeedPart == "" {
+ warnf("the seed part is empty")
return
}
- hashSeedBytes, err := base64.StdEncoding.DecodeString(content.HashSeed)
+ seedPart, err := base64.StdEncoding.DecodeString(content.SeedPart)
if err != nil {
- warnf("we couldn't decode their hash seed: %s", err)
+ warnf("we couldn't decode their seed part: %s", err)
return
}
- if len(hashSeedBytes) < 32 {
- warnf("their hash seed is fewer than 32 bytes")
+ if len(seedPart) < 32 {
+ warnf("their seed part is fewer than 32 bytes")
return
}
@@ 330,7 330,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success
defer el.Save()
defer el.Unlock()
- el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, content.HashSeed)
+ el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, seedPart)
return true
}
@@ 665,7 665,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
var publicCircuit math.EvalsCircuit
// public
- publicCircuit.HashSeeds = el.GetHashSeeds()
+ publicCircuit.HashSeed = el.GetHashSeed()
publicCircuit.Inputs = make([]frontend.Variable, len(*el.FinalJoinIDs))
for i, joinID := range *el.FinalJoinIDs {
publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input)
@@ 695,7 695,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
// ensure our output hashes to the given hash
{
outputBytes := output.Bytes()
- outputHash, err := mimc.Sum(el.LocalVoter.HashSeed, outputBytes[:])
+ outputHash, err := mimc.Sum(el.GetHashSeed(), outputBytes[:])
if err != nil {
errorf("couldn't hash output: %s", err)
return
@@ 847,7 847,7 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
}
outputHashes[i].Assign((*evaler.OutputHashes)[voterJoinIDIndex])
}
- publicCircuit.HashSeeds = el.GetHashSeeds()
+ publicCircuit.HashSeed = el.GetHashSeed()
publicCircuit.HashSelects = hashSelects
publicCircuit.OutputHashes = outputHashes
publicCircuit.Sum.Assign(sum)
M election/voter.go => election/voter.go +8 -8
@@ 23,7 23,7 @@ type Voter struct {
Input fr.Element `json:"input"`
JoinEvt event.Event `json:"join_evt"`
PubKey [32]byte `json:"pub_key"`
- HashSeed string `json:"hash_seed"`
+ SeedPart []byte `json:"seed_part"`
JoinIDIndex *uint `json:"join_id_index,omitempty"`
Output *fr.Element `json:"output,omitempty"`
@@ 47,12 47,12 @@ type LocalVoter struct {
SumVerifyingKey *MarshallableVerifyingKey `json:"sum_verifying_key,omitempty"`
}
-func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, hashSeed string) *Voter {
+func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, seedPart []byte) *Voter {
return &Voter{
Input: *new(fr.Element).Set(input),
JoinEvt: *joinEvt,
PubKey: *pubKey,
- HashSeed: hashSeed,
+ SeedPart: seedPart,
}
}
@@ 104,7 104,7 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore)
CreateID: el.CreateEvt.ID,
Input: base64.StdEncoding.EncodeToString(inputBytes[:]),
PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]),
- HashSeed: base64.StdEncoding.EncodeToString(seedBytes[:]),
+ SeedPart: base64.StdEncoding.EncodeToString(seedBytes[:]),
})
if err != nil {
return fmt.Errorf("couldn't send join messages: %s", err)
@@ 171,7 171,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
evalVerifyingKey groth16.VerifyingKey
)
{
- r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs), len(el.Candidates))
+ r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs), len(el.Candidates))
if err != nil {
return fmt.Errorf("couldn't compile eval circuit: %s", err)
}
@@ 200,7 200,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
sumVerifyingKey groth16.VerifyingKey
)
{
- r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))
+ r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))
if err != nil {
return fmt.Errorf("couldn't compile sum circuit: %s", err)
}
@@ 346,7 346,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
evalsIDs[i] = *voter.EvalsID
}
- r1cs, err := math.SumCircuitR1CS(el.GetHashSeeds(), len(*el.FinalJoinIDs))
+ r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))
proofs := make([]string, numVoters)
{
@@ 355,7 355,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
outputs := make([]frontend.Variable, numVoters)
var witness math.SumCircuit
- witness.HashSeeds = el.GetHashSeeds()
+ witness.HashSeed = el.GetHashSeed()
witness.HashSelects = hashSelects
witness.OutputHashes = outputHashes
witness.Sum.Assign(el.LocalVoter.Sum)
A math/hash.go => math/hash.go +11 -0
@@ 0,0 1,11 @@
+package math
+
+import "golang.org/x/crypto/sha3"
+
+func HashSeedFromSeedParts(seedParts [][]byte) string {
+ res := sha3.Sum256(seedParts[0])
+ for _, part := range seedParts[1:] {
+ res = sha3.Sum256(append(res[:], part...))
+ }
+ return string(res[:])
+}
M math/lagrange_test.go => math/lagrange_test.go +6 -7
@@ 2,7 2,6 @@ package math
import (
"crypto/rand"
- "encoding/base64"
"io"
"testing"
@@ 66,16 65,16 @@ func TestRandomPoly(t *testing.T) {
}
// hash seeds don't matter
- hashSeeds := make([]string, n)
- for i := range hashSeeds {
- var hashSeed [32]byte
- if _, err := io.ReadFull(rand.Reader, hashSeed[:]); err != nil {
+ seedParts := make([][]byte, n)
+ for i := range seedParts {
+ var seedPart [32]byte
+ if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil {
t.Fatal(err)
}
- hashSeeds[i] = base64.RawStdEncoding.EncodeToString(hashSeed[:])
+ seedParts[i] = seedPart[:]
}
- poly := NewRandomPoly(ballot, hashSeeds, inputs)
+ poly := NewRandomPoly(ballot, HashSeedFromSeedParts(seedParts), inputs)
points := make([]Point, n)
for i := 0; i < n; i++ {
genx:
M math/poly.go => math/poly.go +10 -10
@@ 10,10 10,10 @@ import (
)
type Poly struct {
- Ballot [][]byte `json:"ballot"`
- Coeffs []*fr.Element `json:"coeffs"`
- HashSeeds []string `json:"hash_seeds"`
- Inputs []*fr.Element `json:"inputs"`
+ Ballot [][]byte `json:"ballot"`
+ Coeffs []*fr.Element `json:"coeffs"`
+ HashSeed string `json:"hash_seed"`
+ Inputs []*fr.Element `json:"inputs"`
cache *polyCache `json:"-"`
}
@@ 33,7 33,7 @@ type outputCache struct {
// NewRandomPoly generates a random polynomial of degree numVoters-1, using
// `ballot' as the constant term
-func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *Poly {
+func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly {
coeffs := make([]*fr.Element, len(inputs) - 1)
var err error
@@ 44,7 44,7 @@ func NewRandomPoly(ballot [][]byte, hashSeeds []string, inputs []*fr.Element) *P
}
}
- return &Poly{ballot, coeffs, hashSeeds, inputs, nil}
+ return &Poly{ballot, coeffs, hashSeed, inputs, nil}
}
func (p *Poly) setupCache() {
@@ 72,10 72,10 @@ func (p *Poly) setupCache() {
// calculate outputs and their hashes
p.cache.outputHashes = make(map[*fr.Element]outputCache)
- for i, input := range p.Inputs {
+ for _, input := range p.Inputs {
output := p.eval(input)
byts := output.Bytes()
- hash, err := mimc.Sum(p.HashSeeds[i], byts[:])
+ hash, err := mimc.Sum(p.HashSeed, byts[:])
if err != nil {
panic(err)
}
@@ 85,7 85,7 @@ func (p *Poly) setupCache() {
// compile R1CS
{
var err error
- p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeeds, len(p.Inputs), len(p.Ballot))
+ p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeed, len(p.Inputs), len(p.Ballot))
if err != nil {
panic(err)
}
@@ 95,7 95,7 @@ func (p *Poly) setupCache() {
{
witness := &p.cache.witness
// public
- witness.HashSeeds = p.HashSeeds
+ witness.HashSeed = p.HashSeed
witness.Inputs = make([]frontend.Variable, len(p.Inputs))
witness.OutputHashes = make([]frontend.Variable, len(p.Inputs))
for i, input := range p.Inputs {
M math/zk.go => math/zk.go +13 -19
@@ 10,7 10,7 @@ import (
)
type EvalsCircuit struct {
- HashSeeds []string
+ HashSeed string
Inputs []frontend.Variable `gnark:",public"`
OutputHashes []frontend.Variable `gnark:",public"`
@@ 22,9 22,6 @@ type EvalsCircuit struct {
}
func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
- if len(circuit.HashSeeds) != len(circuit.Inputs) {
- return errors.New("len(circuit.HashSeeds) != len(circuit.Inputs)")
- }
if len(circuit.Inputs) != len(circuit.OutputHashes) {
return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)")
}
@@ 47,7 44,7 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSyste
}
// hash
- mimc, err := mimc.NewMiMC(circuit.HashSeeds[i], ecc.BLS12_381)
+ mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381)
if err != nil {
return err
}
@@ 72,9 69,9 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSyste
return nil
}
-func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
+func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
var circuit EvalsCircuit
- circuit.HashSeeds = hashSeeds
+ circuit.HashSeed = hashSeed
circuit.Inputs = make([]frontend.Variable, numVoters)
circuit.OutputHashes = make([]frontend.Variable, numVoters)
circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates)
@@ 87,7 84,7 @@ func EvalsCircuitR1CS(hashSeeds []string, numVoters int, numCandidates int) (fro
}
type SumCircuit struct {
- HashSeeds []string
+ HashSeed string
HashSelects []frontend.Variable `gnark:",public"`
OutputHashes []frontend.Variable `gnark:",public"`
Sum frontend.Variable `gnark:",public"`
@@ 96,9 93,6 @@ type SumCircuit struct {
}
func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error {
- if len(circuit.HashSeeds) != len(circuit.HashSelects) {
- return errors.New("len(circuit.HashSeeds) != len(circuit.HashSelects)")
- }
if len(circuit.HashSelects) != len(circuit.OutputHashes) {
return errors.New("len(circuit.HashSelects) != len(circuit.OutputHashes)")
}
@@ 108,13 102,13 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem)
// prove the secret outputs hash to public hashes
zero := cs.Constant(0)
+
+ mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381)
+ if err != nil {
+ return err
+ }
// TODO: this seems verify inefficient
- for i, hashSeed := range circuit.HashSeeds {
- mimc, err := mimc.NewMiMC(hashSeed, ecc.BLS12_381)
- if err != nil {
- return err
- }
- bit := circuit.HashSelects[i]
+ for _, bit := range circuit.HashSelects {
for j, output := range circuit.Outputs {
cs.AssertIsEqual(
cs.Select(bit, mimc.Hash(cs, output), zero),
@@ 134,9 128,9 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem)
return nil
}
-func SumCircuitR1CS(hashSeeds []string, numVoters int) (frontend.CompiledConstraintSystem, error) {
+func SumCircuitR1CS(hashSeed string, numVoters int) (frontend.CompiledConstraintSystem, error) {
var circuit SumCircuit
- circuit.HashSeeds = hashSeeds
+ circuit.HashSeed = hashSeed
circuit.HashSelects = make([]frontend.Variable, numVoters)
circuit.OutputHashes = make([]frontend.Variable, numVoters)
circuit.Outputs = make([]frontend.Variable, numVoters)