~edwargix/tallyard

8d759a826f4593ba204626203019e34fe669d76e — David Florness 2 years ago f151fc7
Upgrade gnark from v0.5.2 to v0.6.4

gnark v0.6.0 contains many breaking changes
https://github.com/ConsenSys/gnark/blob/master/CHANGELOG.md#v060---2022-01-03

The tallyard version is also bumped to v0.5.0 since the SeedPart param in the
JoinElection message (xyz.tallyard.join) has been dropped.
M cmd/tallyard/main.go => cmd/tallyard/main.go +1 -1
@@ 174,7 174,7 @@ func main() {
		for i, joinID := range *el.FinalJoinIDs {
			inputs[i] = &el.Joins[joinID].Input
		}
		el.LocalVoter.Poly = math.NewRandomPoly(*ballot, el.GetHashSeed(), inputs)
		el.LocalVoter.Poly = math.NewRandomPoly(*ballot, inputs)
		el.Save()
	}


M election/election.go => election/election.go +0 -21
@@ 7,7 7,6 @@ import (

	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
	"tallyard.xyz/math"
)

type Election struct {


@@ 25,8 24,6 @@ type Election struct {

	// Save election to disk.  Set by the containing ElectionsMap at runtime
	Save         func()                `json:"-"`

	hashSeed     *string               `json:"-"`
}

func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election {


@@ 56,21 53,3 @@ func (el *Election) UnmarshalJSON(b []byte) error {
	}
	return nil
}

func (el *Election) GetHashSeed() string {
	if el.hashSeed != nil {
		return *el.hashSeed
	}

	if el.FinalJoinIDs == nil {
		panic("GatHashSeeds called before election started")
	}

	seedParts := make([][]byte, len(*el.FinalJoinIDs))
	for i, joinID := range *el.FinalJoinIDs {
		seedParts[i] = el.Joins[joinID].SeedPart
	}
	hashSeed := math.HashSeedFromSeedParts(seedParts)
	el.hashSeed = &hashSeed
	return hashSeed
}

M election/marshal_test.go => election/marshal_test.go +4 -24
@@ 39,18 39,6 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	// PubKey / PrivKey
	pubKey, privKey, err := box.GenerateKey(rand.Reader)

	// HashSeed(s)
	seedParts := make([][]byte, numVoters)
	for i := range seedParts {
		var seedPart [32]byte
		if _, err := io.ReadFull(rand.Reader, seedPart[:]); err != nil {
			t.Fatal(err)
		}
		seedParts[i] = seedPart[:]
	}

	localHashSeed := seedParts[joinIDIndex]

	// JoinEvt
	joinEvt := &event.Event{
		Sender:    "@test:example.org",


@@ 64,16 52,13 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
				CreateID: "createID",
				Input:    base64.StdEncoding.EncodeToString(inputBytes[:]),
				PubKey:   base64.StdEncoding.EncodeToString((*pubKey)[:]),
				SeedPart: base64.StdEncoding.EncodeToString(localHashSeed),
			},
		},
	}

	localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey, localHashSeed), privKey)
	localVoter := NewLocalVoter(NewVoter(input, joinEvt, pubKey), privKey)
	localVoter.JoinIDIndex = &joinIDIndex

	hashSeed := math.HashSeedFromSeedParts(seedParts)

	// Output
	localVoter.Output, err = new(fr.Element).SetRandom()
	if err != nil {


@@ 112,7 97,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	const numCandidates = 5

	{
		r1cs, err := math.EvalsCircuitR1CS(hashSeed, numVoters, numCandidates)
		r1cs, err := math.EvalsCircuitR1CS(numVoters, numCandidates)
		if err != nil {
			t.Fatal(err)
		}


@@ 128,7 113,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	}

	{
		r1cs, err := math.SumCircuitR1CS(hashSeed, numVoters)
		r1cs, err := math.SumCircuitR1CS(numVoters)
		if err != nil {
			t.Fatal(err)
		}


@@ 151,7 136,7 @@ func randomLocalVoter(t *testing.T) *LocalVoter {
	}

	// Poly
	localVoter.Poly = math.NewRandomPoly(ballot, hashSeed, inputs)
	localVoter.Poly = math.NewRandomPoly(ballot, inputs)

	return localVoter
}


@@ 186,11 171,6 @@ func ensureEqual(t *testing.T, lv1 *LocalVoter, lv2 *LocalVoter) {
		t.Errorf("pubkeys not equal")
	}

	// HashSeed
	if bytes.Compare(lv1.SeedPart, lv2.SeedPart) != 0 {
		t.Errorf("hash seeds not equal")
	}

	// JoinIDIndex
	if *lv2.JoinIDIndex != *lv1.JoinIDIndex {
		t.Error("join ID indices not equal")

M election/msg.go => election/msg.go +20 -28
@@ 69,7 69,6 @@ type JoinElectionContent struct {
	CreateID   id.EventID `json:"create_id"`
	Input      string     `json:"input"`
	PubKey     string     `json:"pub_key"`
	SeedPart   string     `json:"seed_part"`
}

type StartElectionContent struct {


@@ 305,21 304,6 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success 
		return
	}

	// SeedPart
	if content.SeedPart == "" {
		warnf("the seed part is empty")
		return
	}
	seedPart, err := base64.StdEncoding.DecodeString(content.SeedPart)
	if err != nil {
		warnf("we couldn't decode their seed part: %s", err)
		return
	}
	if len(seedPart) < 32 {
		warnf("their seed part is fewer than 32 bytes")
		return
	}

	// Input
	byts, err := base64.StdEncoding.DecodeString(content.Input)
	if err != nil {


@@ 354,7 338,7 @@ func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success 
	defer el.Save()
	defer el.Unlock()

	el.Joins[evt.ID] = NewVoter(input, evt, &pubKey, seedPart)
	el.Joins[evt.ID] = NewVoter(input, evt, &pubKey)

	return true
}


@@ 740,14 724,13 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
		var publicCircuit math.EvalsCircuit

		// public
		publicCircuit.HashSeed = el.GetHashSeed()
		publicCircuit.Inputs = make([]frontend.Variable, len(FinalJoinIDs))
		for i, joinID := range FinalJoinIDs {
			publicCircuit.Inputs[i].Assign(&el.Joins[joinID].Input)
			publicCircuit.Inputs[i] = &el.Joins[joinID].Input
		}
		publicCircuit.OutputHashes = make([]frontend.Variable, len(FinalJoinIDs))
		for i, outputHash := range outputHashes {
			publicCircuit.OutputHashes[i].Assign(outputHash)
			publicCircuit.OutputHashes[i] = outputHash
		}

		// private


@@ 760,7 743,12 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
			errorf("our evals verifying key is nil")
			return
		}
		err := groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk(), &publicCircuit)
		witness, err := frontend.NewWitness(&publicCircuit, ecc.BLS12_381, frontend.PublicOnly())
		if err != nil {
			errorf("couldn't create witness: %w", err)
			return
		}
		err = groth16.Verify(proof, el.LocalVoter.EvalVerifyingKey.Vk(), witness)
		if err != nil {
			warnf("poly eval proof verification failed: %s", err)
			return


@@ 770,7 758,7 @@ func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
	// ensure our output hashes to the given hash
	{
		outputBytes := output.Bytes()
		outputHash, err := mimc.Sum(el.GetHashSeed(), outputBytes[:])
		outputHash, err := mimc.Sum(outputBytes[:])
		if err != nil {
			errorf("couldn't hash output: %s", err)
			return


@@ 939,16 927,15 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
		for i, joinID := range *el.FinalJoinIDs {
			evaler := el.Joins[joinID]
			if uint(i) == voterJoinIDIndex {
				hashSelects[i].Assign(1)
				hashSelects[i] = 1
			} else {
				hashSelects[i].Assign(0)
				hashSelects[i] = 0
			}
			outputHashes[i].Assign((*evaler.OutputHashes)[voterJoinIDIndex])
			outputHashes[i] = (*evaler.OutputHashes)[voterJoinIDIndex]
		}
		publicCircuit.HashSeed = el.GetHashSeed()
		publicCircuit.HashSelects = hashSelects
		publicCircuit.OutputHashes = outputHashes
		publicCircuit.Sum.Assign(sum)
		publicCircuit.Sum = sum
		publicCircuit.Outputs = make([]frontend.Variable, n)

		if el.LocalVoter.SumVerifyingKey == nil {


@@ 957,7 944,12 @@ func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
			errorf("our sum verifying key is nil")
			return
		}
		err := groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk(), &publicCircuit)
		witness, err := frontend.NewWitness(&publicCircuit, ecc.BLS12_381, frontend.PublicOnly())
		if err != nil {
			errorf("couldn't create witness: %w", err)
			return
		}
		err = groth16.Verify(proof, el.LocalVoter.SumVerifyingKey.Vk(), witness)
		if err != nil {
			warnf("sum proof verification failed: %s", err)
			return

M election/voter.go => election/voter.go +20 -23
@@ 8,6 8,7 @@ import (
	"fmt"
	"io"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
	"github.com/consensys/gnark/backend/groth16"
	"github.com/consensys/gnark/frontend"


@@ 23,7 24,6 @@ type Voter struct {
	Input    fr.Element  `json:"input"`
	JoinEvt  event.Event `json:"join_evt"`
	PubKey   [32]byte    `json:"pub_key"`
	SeedPart []byte      `json:"seed_part"`

	JoinIDIndex  *uint       `json:"join_id_index,omitempty"`
	Output       *fr.Element `json:"output,omitempty"`


@@ 47,12 47,11 @@ type LocalVoter struct {
	SumVerifyingKey  *VerifyingKeyFile `json:"sum_verifying_key,omitempty"`
}

func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte, seedPart []byte) *Voter {
func NewVoter(input *fr.Element, joinEvt *event.Event, pubKey *[32]byte) *Voter {
	return &Voter{
		Input:    *new(fr.Element).Set(input),
		JoinEvt:  *joinEvt,
		PubKey:   *pubKey,
		SeedPart: seedPart,
	}
}



@@ 93,11 92,6 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore)
	}
	inputBytes := input.Bytes()

	var seedBytes [32]byte
	if _, err := io.ReadFull(rand.Reader, seedBytes[:]); err != nil {
		return fmt.Errorf("couldn't read random bytes: %s", err)
	}

	commitmentHash, err := CalculateCommitment(el.CreateEvt.Content)
	if err != nil {
		return fmt.Errorf("couldn't calculate join event's commitment: %s", err)


@@ 110,7 104,6 @@ func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore)
		CreateID:   el.CreateEvt.ID,
		Input:      base64.StdEncoding.EncodeToString(inputBytes[:]),
		PubKey:     base64.StdEncoding.EncodeToString((*pubKey)[:]),
		SeedPart:   base64.StdEncoding.EncodeToString(seedBytes[:]),
	})
	if err != nil {
		return fmt.Errorf("couldn't send join messages: %s", err)


@@ 192,7 185,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
		evalVerifyingKey  groth16.VerifyingKey
	)
	{
		r1cs, err := math.EvalsCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs), len(el.Candidates))
		r1cs, err := math.EvalsCircuitR1CS(len(*el.FinalJoinIDs), len(el.Candidates))
		if err != nil {
			return fmt.Errorf("couldn't compile eval circuit: %s", err)
		}


@@ 222,7 215,7 @@ func (el *Election) SendProvingKeys(client *mautrix.Client, eventStore *EventSto
		sumVerifyingKey  groth16.VerifyingKey
	)
	{
		r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))
		r1cs, err := math.SumCircuitR1CS(len(*el.FinalJoinIDs))
		if err != nil {
			return fmt.Errorf("couldn't compile sum circuit: %s", err)
		}


@@ 399,7 392,7 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
		evalsIDs[i] = *voter.EvalsID
	}

	r1cs, err := math.SumCircuitR1CS(el.GetHashSeed(), len(*el.FinalJoinIDs))
	r1cs, err := math.SumCircuitR1CS(len(*el.FinalJoinIDs))

	proofs := make([]string, numVoters)
	{


@@ 407,29 400,33 @@ func (el *Election) SendSum(client *mautrix.Client, eventStore *EventStore) erro
		outputHashes := make([]frontend.Variable, numVoters)
		outputs := make([]frontend.Variable, numVoters)

		var witness math.SumCircuit
		witness.HashSeed = el.GetHashSeed()
		witness.HashSelects = hashSelects
		witness.OutputHashes = outputHashes
		witness.Sum.Assign(el.LocalVoter.Sum)
		witness.Outputs = outputs
		var circuit math.SumCircuit
		circuit.HashSelects = hashSelects
		circuit.OutputHashes = outputHashes
		circuit.Sum = el.LocalVoter.Sum
		circuit.Outputs = outputs

		ourJoinIDIndex := *el.LocalVoter.JoinIDIndex
		for i, joinID := range *el.FinalJoinIDs {
			if uint(i) == ourJoinIDIndex {
				hashSelects[i].Assign(1)
				hashSelects[i] = 1
			} else {
				hashSelects[i].Assign(0)
				hashSelects[i] = 0
			}
			evaler := el.Joins[joinID]
			outputHashForUs := (*evaler.OutputHashes)[ourJoinIDIndex]
			outputHashes[i].Assign(outputHashForUs)
			outputs[i].Assign(evaler.Output)
			outputHashes[i] = outputHashForUs
			outputs[i] = evaler.Output
		}

		witness, err := frontend.NewWitness(&circuit, ecc.BLS12_381)
		if err != nil {
			return fmt.Errorf("couldn't create witness: %w", err)
		}

		for i, joinID := range *el.FinalJoinIDs {
			voter := *el.Joins[joinID]
			proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk(), &witness)
			proof, err := groth16.Prove(r1cs, voter.SumProvingKey.Pk(), witness)
			if err != nil {
				el.RUnlock()
				return fmt.Errorf("couldn't prove sum: %s", err)

M go.mod => go.mod +26 -5
@@ 1,17 1,38 @@
module tallyard.xyz

go 1.13
go 1.17

require (
	github.com/consensys/gnark v0.5.2
	github.com/consensys/gnark-crypto v0.5.3
	github.com/fxamacker/cbor/v2 v2.3.1 // indirect
	github.com/consensys/gnark v0.6.4
	github.com/consensys/gnark-crypto v0.6.1
	github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385
	github.com/kyoh86/xdg v1.2.0
	github.com/rivo/tview v0.0.0-20211202162923-2a6de950f73b
	github.com/sirupsen/logrus v1.8.1
	golang.org/x/crypto v0.0.0-20220312131142-6068a2e6cfdc
	golang.org/x/mod v0.5.1
	golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect
	maunium.net/go/mautrix v0.10.11
)

require (
	github.com/fxamacker/cbor/v2 v2.3.1 // indirect
	github.com/gdamore/encoding v1.0.0 // indirect
	github.com/gorilla/mux v1.8.0 // indirect
	github.com/gorilla/websocket v1.4.2 // indirect
	github.com/kr/text v0.2.0 // indirect
	github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
	github.com/mattn/go-runewidth v0.0.13 // indirect
	github.com/mmcloughlin/addchain v0.4.0 // indirect
	github.com/rivo/uniseg v0.2.0 // indirect
	github.com/tidwall/gjson v1.14.0 // indirect
	github.com/tidwall/match v1.1.1 // indirect
	github.com/tidwall/pretty v1.2.0 // indirect
	github.com/tidwall/sjson v1.2.4 // indirect
	github.com/x448/float16 v0.8.4 // indirect
	golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect
	golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect
	golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
	golang.org/x/text v0.3.7 // indirect
	gopkg.in/yaml.v2 v2.4.0 // indirect
	maunium.net/go/maulogger/v2 v2.3.2 // indirect
)

M go.sum => go.sum +12 -6
@@ 1,8 1,9 @@
github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI=
github.com/consensys/gnark v0.5.2 h1:/TTBStGJXkJqFVYFT7YnWmd0PedZlavUb7qOHO2UMEg=
github.com/consensys/gnark v0.5.2/go.mod h1:gaY1Ij1sp3TnLexb6y9y0KslzqVDvRg+XKldbXXK7ss=
github.com/consensys/gnark-crypto v0.5.3 h1:4xLFGZR3NWEH2zy+YzvzHicpToQR8FXFbfLNvpGB+rE=
github.com/consensys/gnark-crypto v0.5.3/go.mod h1:hOdPlWQV1gDLp7faZVeg8Y0iEPFaOUnCc4XeCCk96p0=
github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI=
github.com/consensys/gnark v0.6.4 h1:V+iSUwRQvTqug3z+Sb8ve5C9JWTU8bP29sU3HCmlARc=
github.com/consensys/gnark v0.6.4/go.mod h1:rJwNZk2xhK/V2yYlqBS9Y4FYvZ1347lWejsIr2HRVak=
github.com/consensys/gnark-crypto v0.6.1 h1:MuWaJyWzSw8wQUOfiZOlRwYjfweIj8dM/u2NN6m0O04=
github.com/consensys/gnark-crypto v0.6.1/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=


@@ 14,6 15,7 @@ github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo
github.com/gdamore/tcell/v2 v2.4.1-0.20210905002822-f057f0a857a1/go.mod h1:Az6Jt+M5idSED2YPGtwnfJV0kXohgdCBPmHGSYc1r04=
github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385 h1:O5oaOCRcXvNnsPikhB6xGd4a1bbfgcuFQCQgDB4tM7Y=
github.com/gdamore/tcell/v2 v2.4.1-0.20211227212015-3260e4ac4385/go.mod h1:I8YJFI9gzgl4dHi9UlRDZosCW+jYkDA37AXmXvL51w4=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=


@@ 21,8 23,9 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kyoh86/xdg v1.2.0 h1:CERuT/ShdTDj+A2UaX3hQ3mOV369+Sj+wyn2nIRIIkI=
github.com/kyoh86/xdg v1.2.0/go.mod h1:/mg8zwu1+qe76oTFUBnyS7rJzk7LLC0VGEzJyJ19DHs=
github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c=


@@ 32,6 35,9 @@ github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i
github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU=
github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.11/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY=
github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU=
github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/tview v0.0.0-20211202162923-2a6de950f73b h1:EMgbQ+bOHWkl0Ptano8M0yrzVZkxans+Vfv7ox/EtO8=

D math/hash.go => math/hash.go +0 -11
@@ 1,11 0,0 @@
package math

import "golang.org/x/crypto/sha3"

func HashSeedFromSeedParts(seedParts [][]byte) string {
	res := sha3.Sum256(seedParts[0])
	for _, part := range seedParts[1:] {
		res = sha3.Sum256(append(res[:], part...))
	}
	return string(res[:])
}

M math/lagrange_test.go => math/lagrange_test.go +1 -1
@@ 74,7 74,7 @@ func TestRandomPoly(t *testing.T) {
			seedParts[i] = seedPart[:]
		}

		poly := NewRandomPoly(ballot, HashSeedFromSeedParts(seedParts), inputs)
		poly := NewRandomPoly(ballot, inputs)
		points := make([]Point, n)
		for i := 0; i < n; i++ {
		genx:

M math/poly.go => math/poly.go +24 -18
@@ 3,16 3,17 @@ package math
import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc"
	"github.com/consensys/gnark/backend/groth16"
	"github.com/consensys/gnark/backend/witness"
	"github.com/consensys/gnark/frontend"
)

type Poly struct {
	Ballot   [][]byte      `json:"ballot"`
	Coeffs   []*fr.Element `json:"coeffs"`
	HashSeed string        `json:"hash_seed"`
	Inputs   []*fr.Element `json:"inputs"`

	cache *polyCache `json:"-"`


@@ 23,7 24,7 @@ type polyCache struct {
	constant     *fr.Element
	outputHashes map[*fr.Element]outputCache
	r1cs         frontend.CompiledConstraintSystem
	witness      EvalsCircuit
	witness      *witness.Witness
}

type outputCache struct {


@@ 33,7 34,7 @@ type outputCache struct {

// NewRandomPoly generates a random polynomial of degree numVoters-1, using
// `ballot' as the constant term
func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly {
func NewRandomPoly(ballot [][]byte, inputs []*fr.Element) *Poly {
	coeffs := make([]*fr.Element, len(inputs) - 1)

	var err error


@@ 44,7 45,7 @@ func NewRandomPoly(ballot [][]byte, hashSeed string, inputs []*fr.Element) *Poly
		}
	}

	return &Poly{ballot, coeffs, hashSeed, inputs, nil}
	return &Poly{ballot, coeffs, inputs, nil}
}

func (p *Poly) setupCache() {


@@ 75,7 76,7 @@ func (p *Poly) setupCache() {
	for _, input := range p.Inputs {
		output := p.eval(input)
		byts := output.Bytes()
		hash, err := mimc.Sum(p.HashSeed, byts[:])
		hash, err := mimc.Sum(byts[:])
		if err != nil {
			panic(err)
		}


@@ 85,33 86,38 @@ func (p *Poly) setupCache() {
	// compile R1CS
	{
		var err error
		p.cache.r1cs, err = EvalsCircuitR1CS(p.HashSeed, len(p.Inputs), len(p.Ballot))
		p.cache.r1cs, err = EvalsCircuitR1CS(len(p.Inputs), len(p.Ballot))
		if err != nil {
			panic(err)
		}
	}

	circuit := &EvalsCircuit{}
	// create witness
	{
		witness := &p.cache.witness
		// public
		witness.HashSeed = p.HashSeed
		witness.Inputs = make([]frontend.Variable, len(p.Inputs))
		witness.OutputHashes = make([]frontend.Variable, len(p.Inputs))
		circuit.Inputs = make([]frontend.Variable, len(p.Inputs))
		circuit.OutputHashes = make([]frontend.Variable, len(p.Inputs))
		for i, input := range p.Inputs {
			witness.Inputs[i].Assign(input)
			witness.OutputHashes[i].Assign(p.cache.outputHashes[input].hash)
			circuit.Inputs[i] = input
			circuit.OutputHashes[i] = p.cache.outputHashes[input].hash
		}
		// private
		witness.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
		circuit.BallotBits = make([]frontend.Variable, len(p.Ballot)*len(p.Ballot[0]))
		for i, bit := range p.cache.ballotBits {
			witness.BallotBits[i].Assign(bit)
			circuit.BallotBits[i] = bit
		}
		witness.Coeffs = make([]frontend.Variable, len(p.Coeffs))
		circuit.Coeffs = make([]frontend.Variable, len(p.Coeffs))
		for i, coeff := range p.Coeffs {
			witness.Coeffs[i].Assign(coeff)
			circuit.Coeffs[i] = coeff
		}
		witness.Constant.Assign(p.cache.constant)
		circuit.Constant = p.cache.constant
	}

	var err error
	p.cache.witness, err = frontend.NewWitness(circuit, ecc.BLS12_381)
	if err != nil {
		panic(fmt.Errorf("couldn't create witness: %w", err))
	}
}



@@ 120,7 126,7 @@ func (p *Poly) EvalAndProve(input *fr.Element, provingKey groth16.ProvingKey) (*
		p.setupCache()
	}

	proof, err := groth16.Prove(p.cache.r1cs, provingKey, &p.cache.witness)
	proof, err := groth16.Prove(p.cache.r1cs, provingKey, p.cache.witness)
	if err != nil {
		panic(err)
	}

M math/zk.go => math/zk.go +10 -14
@@ 10,7 10,6 @@ import (
)

type EvalsCircuit struct {
	HashSeed     string
	Inputs       []frontend.Variable `gnark:",public"`
	OutputHashes []frontend.Variable `gnark:",public"`



@@ 21,7 20,7 @@ type EvalsCircuit struct {
	Constant   frontend.Variable   `gnark:",secret"`
}

func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *EvalsCircuit) Define(api frontend.API) error {
	if len(circuit.Inputs) != len(circuit.OutputHashes) {
		return errors.New("len(circuit.Inputs) != len(circuit.OutputHashes)")
	}


@@ 44,7 43,7 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
		}

		// hash
		mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381, api)
		mimc, err := mimc.NewMiMC(api)
		if err != nil {
			return err
		}


@@ 55,9 54,9 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
	// prove the constant is valid (i.e. a string of bits separated by 7
	// zeros)
	// TODO: support more than 255 voters (limit from usage of byte)
	zero := api.Constant(0)
	shift := api.Constant(1 << 8)
	slot := api.Constant(1)
	zero := frontend.Variable(0)
	shift := frontend.Variable(1 << 8)
	slot := frontend.Variable(1)
	constructedConstant := zero
	for i := len(circuit.BallotBits)-1; i >= 0; i-- {
		bit := circuit.BallotBits[i]


@@ 70,9 69,8 @@ func (circuit *EvalsCircuit) Define(curveID ecc.ID, api frontend.API) error {
	return nil
}

func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
func EvalsCircuitR1CS(numVoters int, numCandidates int) (frontend.CompiledConstraintSystem, error) {
	var circuit EvalsCircuit
	circuit.HashSeed = hashSeed
	circuit.Inputs = make([]frontend.Variable, numVoters)
	circuit.OutputHashes = make([]frontend.Variable, numVoters)
	circuit.BallotBits = make([]frontend.Variable, numCandidates*numCandidates)


@@ 85,7 83,6 @@ func EvalsCircuitR1CS(hashSeed string, numVoters int, numCandidates int) (fronte
}

type SumCircuit struct {
	HashSeed     string
	HashSelects  []frontend.Variable `gnark:",public"`
	OutputHashes []frontend.Variable `gnark:",public"`
	Sum          frontend.Variable   `gnark:",public"`


@@ 93,7 90,7 @@ type SumCircuit struct {
	Outputs      []frontend.Variable `gnark:",secret"`
}

func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *SumCircuit) Define(api frontend.API) error {
	if len(circuit.HashSelects) != len(circuit.OutputHashes) {
		return errors.New("len(circuit.HashSelects) != len(circuit.OutputHashes)")
	}


@@ 102,9 99,9 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error {
	}

	// prove the secret outputs hash to public hashes
	zero := api.Constant(0)
	zero := frontend.Variable(0)

	mimc, err := mimc.NewMiMC(circuit.HashSeed, ecc.BLS12_381, api)
	mimc, err := mimc.NewMiMC(api)
	if err != nil {
		return err
	}


@@ 131,9 128,8 @@ func (circuit *SumCircuit) Define(curveID ecc.ID, api frontend.API) error {
	return nil
}

func SumCircuitR1CS(hashSeed string, numVoters int) (frontend.CompiledConstraintSystem, error) {
func SumCircuitR1CS(numVoters int) (frontend.CompiledConstraintSystem, error) {
	var circuit SumCircuit
	circuit.HashSeed = hashSeed
	circuit.HashSelects = make([]frontend.Variable, numVoters)
	circuit.OutputHashes = make([]frontend.Variable, numVoters)
	circuit.Outputs = make([]frontend.Variable, numVoters)

M version.go => version.go +1 -1
@@ 1,3 1,3 @@
package tallyard

const Version string = "v0.4.5"
const Version string = "v0.5.0"