~edwargix/tallyard

a313d735a011db11ea258d800d9a630737c0d169 — David Florness 4 years ago a840e44
Try to mitigate against replay attacks and refactor event retrieval
M cmd/tallyard/main.go => cmd/tallyard/main.go +51 -38
@@ 5,9 5,9 @@ import (
	"os"

	"github.com/kyoh86/xdg"
	log "github.com/sirupsen/logrus"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"

	"tallyard.xyz/election"
	"tallyard.xyz/math"


@@ 15,41 15,50 @@ import (
	"tallyard.xyz/ui"
)

func electionFilter(localUserID id.UserID) *mautrix.Filter {
	return &mautrix.Filter{
		AccountData: mautrix.FilterPart{
var electionFilter = &mautrix.Filter{
	AccountData: mautrix.FilterPart{
		NotTypes: []event.Type{event.NewEventType("*")},
	},
	Presence: mautrix.FilterPart{
		NotTypes: []event.Type{event.NewEventType("*")},
	},
	Room: mautrix.RoomFilter{
		Ephemeral: mautrix.FilterPart{
			NotTypes: []event.Type{event.NewEventType("*")},
		},
		Presence: mautrix.FilterPart{
			NotTypes: []event.Type{event.NewEventType("*")},
		State: mautrix.FilterPart{
			Types: []event.Type{event.StateRoomName},
		},
		Room: mautrix.RoomFilter{
			Ephemeral: mautrix.FilterPart{
				NotTypes: []event.Type{event.NewEventType("*")},
			},
			State: mautrix.FilterPart{
				Types: []event.Type{event.StateRoomName},
			},
			Timeline: mautrix.FilterPart{
				LazyLoadMembers: true,
				Limit: 50,
				NotSenders: []id.UserID{localUserID},
				Types: []event.Type{
					election.CreateElectionMessage,
					election.JoinElectionMessage,
					election.StartElectionMessage,
					election.EvalMessage,
					election.SumMessage,
					election.ResultMessage,
				},
		Timeline: mautrix.FilterPart{
			LazyLoadMembers: true,
			Types: []event.Type{
				election.CreateElectionMessage,
				election.JoinElectionMessage,
				election.StartElectionMessage,
				election.EvalsMessage,
				election.SumMessage,
				election.ResultMessage,
			},
		},
	}
	},
}

func main() {
	os.MkdirAll(xdg.DataHome() + "/tallyard", 0700)

	// You could set this to any `io.Writer` such as a file
	file, err := os.OpenFile(xdg.DataHome() + "/tallyard/tallyard.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
	if err == nil {
		log.SetLevel(log.DebugLevel)
		log.SetOutput(file)
	} else {
		log.Errorf("failed to open logging file; using default stderr: %s", err)
	}
	log.Info("tallyard starting...")
	defer func() {
		log.Info("tallyard exiting...")
	}()

	authInfo, err := matrix.GetAuthInfo()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)


@@ 68,18 77,16 @@ func main() {
	}
	client.Store = matrix.NewTallyardStore(elections)
	defer func() {
		err := elections.Save()
		if err != nil {
			panic(err)
		}
		elections.Save()
	}()

	syncer := client.Syncer.(*mautrix.DefaultSyncer)
	syncer.OnEvent(debugEventHook)
	syncer.OnEvent(client.Store.(*matrix.TallyardStore).UpdateState)
	election.SetupEventHooks(client, syncer, elections)
	elections.SetupEventHooks(client, syncer)

	go func() {
		res, err := client.CreateFilter(electionFilter(client.UserID))
		res, err := client.CreateFilter(electionFilter)
		if err != nil {
			panic(err)
		}


@@ 111,11 118,11 @@ func main() {
	}

	// wait for election to start
	err = ui.ElectionWaitTUI(client, el)
	err = ui.ElectionWaitTUI(client, el, elections.EventStore)
	if err != nil {
		panic(err)
	}
	if el.StartEvt == nil {
	if el.StartID == nil {
		// election never started; user likely hit C-c
		return
	}


@@ 130,15 137,15 @@ func main() {
	}

	// set random poly with ballot
	el.LocalVoter.Poly = math.NewRandomPoly(uint(len(*el.FinalVoters)-1), 1024, *el.LocalVoter.Ballot)
	el.LocalVoter.Poly = math.NewRandomPoly(uint(len(*el.FinalJoinIDs)-1), 1024, *el.LocalVoter.Ballot)
	el.Save()

	// wait for other voters to finish
	el.WaitForVoters(client)
	el.WaitForJoins(client)

	// send eval if we need to (if LocalVoter.Eval is set, we've already
	// send evals if we need to (if LocalVoter.Evals is set, we've already
	// sent the event)
	if el.LocalVoter.Eval == nil {
	if el.LocalVoter.EvalsID == nil {
		err = el.SendEvals(client)
		if err != nil {
			panic(err)


@@ 156,3 163,9 @@ func main() {

	el.GetSums(client)
}

func debugEventHook(_ mautrix.EventSource, evt *event.Event) {
	log.Debugf("<%[1]s> %[4]s (%[2]s/%[3]s)\n",
		evt.Sender, evt.Type.String(), evt.ID,
		evt.Content.AsMessage().Body)
}

M election/election.go => election/election.go +16 -12
@@ 2,6 2,7 @@ package election

import (
	"encoding/json"
	"fmt"
	"sync"

	"maunium.net/go/mautrix/event"


@@ 11,31 12,34 @@ import (
type Election struct {
	sync.RWMutex

	Candidates  []Candidate           `json:"candidates"`
	CreateEvt   event.Event           `json:"create_evt"`
	FinalVoters *[]id.EventID         `json:"final_voters,omitempty"`
	Joins       map[id.EventID]*Voter `json:"joins"`
	LocalVoter  *LocalVoter           `json:"local_voter,omitempty"`
	Save        func() error          `json:"-"`
	StartEvt    *event.Event          `json:"start_evt,omitempty"`
	Title       string                `json:"title"`
	Candidates   []Candidate           `json:"candidates"`
	CreateEvt    event.Event           `json:"create_evt"`
	FinalJoinIDs *[]id.EventID         `json:"final_voters,omitempty"`
	Joins        map[id.EventID]*Voter `json:"joins"`
	LocalVoter   *LocalVoter           `json:"local_voter,omitempty"`
	RoomID       id.RoomID             `json:"room_id"`
	StartID      *id.EventID           `json:"start_id,omitempty"`
	Title        string                `json:"title"`

	// Save election to disk.  Set by the containing ElectionsMap at runtime
	Save         func()                `json:"-"`
}

func NewElection(candidates []Candidate, createEvt event.Event, title string) *Election {
func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election {
	return &Election{
		Candidates: candidates,
		CreateEvt:  createEvt,
		Joins:      make(map[id.EventID]*Voter),
		RoomID:     roomID,
		Title:      title,
	}
}

func (el *Election) UnmarshalJSON(b []byte) error {
	type Alias Election // create alias to prevent endless loop
	tmp := (*Alias)(el)
	err := json.Unmarshal(b, tmp)
	err := json.Unmarshal(b, (*Alias)(el))
	if err != nil {
		return err
		return fmt.Errorf("couldn't unmarshal election: %s", err)
	}
	// ensure these are the same pointer
	if el.LocalVoter != nil {

A election/event.go => election/event.go +208 -0
@@ 0,0 1,208 @@
package election

import (
	"encoding/json"
	"fmt"
	"sync"

	log "github.com/sirupsen/logrus"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// Used to store all relevant election Events that we've tried processing.  A
// value of nil for a given event ID key in the Events map means that the
// corresponding event was processed unsuccessfully.  Likewise, a non-nil value
// indicates that the event was processed successfully.
//
// This EventStore also uses the provided Client to fetch missing events and
// pipe them through the corresponding EventHandlers when needed.
type EventStore struct {
	sync.RWMutex
	Events        map[id.EventID]*event.Event           `json:"events"`

	Client        *mautrix.Client                       `json:"-"`
	EventHandlers map[event.Type]func(*event.Event)bool `json:"-"`
}

func NewEventStore(client *mautrix.Client, eventHandlers map[event.Type]func(*event.Event)bool) *EventStore {
	return &EventStore{
		Events:        make(map[id.EventID]*event.Event),

		Client:        client,
		EventHandlers: eventHandlers,
	}
}

func (store *EventStore) UnmarshalJSON(b []byte) error {
	type Alias EventStore
	err := json.Unmarshal(b, (*Alias)(store))
	if err != nil {
		return err
	}
	for _, event := range store.Events {
		if event == nil {
			continue
		}
		// event.Type.Class is "UnkownEventType" after unmarshlling
		switch event.Type.Type {
		case CreateElectionMessage.Type:
			event.Type = CreateElectionMessage
		case JoinElectionMessage.Type:
			event.Type = JoinElectionMessage
		case StartElectionMessage.Type:
			event.Type = StartElectionMessage
		case EvalsMessage.Type:
			event.Type = EvalsMessage
		case SumMessage.Type:
			event.Type = SumMessage
		case ResultMessage.Type:
			event.Type = ResultMessage
		}
		if err = event.Content.ParseRaw(event.Type); err != nil {
			return err
		}
	}
	return nil
}

type CreateEvent struct {
	*event.Event
	*CreateElectionContent
}

type JoinEvent struct {
	*event.Event
	*JoinElectionContent
}

type StartEvent struct {
	*event.Event
	*StartElectionContent
}

type EvalsEvent struct {
	*event.Event
	*EvalsMessageContent
}

type SumEvent struct {
	*event.Event
	*SumMessageContent
}

func (store *EventStore) GetCreateEvent(roomID id.RoomID, createID id.EventID) *CreateEvent {
	evt, err := store.getAndHandleEvent(roomID, createID, CreateElectionMessage)
	if err != nil {
		log.Warnf("an error occurred getting create event '%s': %s", createID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	return &CreateEvent{
		evt,
		evt.Content.Parsed.(*CreateElectionContent),
	}
}

func (store *EventStore) GetJoinEvent(roomID id.RoomID, joinID id.EventID) *JoinEvent {
	evt, err := store.getAndHandleEvent(roomID, joinID, JoinElectionMessage)
	if err != nil {
		log.Warnf("an error occurred getting join event '%s': %s", joinID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	return &JoinEvent{
		evt,
		evt.Content.Parsed.(*JoinElectionContent),
	}
}

func (store *EventStore) GetStartEvent(roomID id.RoomID, startID id.EventID) *StartEvent {
	evt, err := store.getAndHandleEvent(roomID, startID, StartElectionMessage)
	if err != nil {
		log.Warnf("an error occurred getting start event '%s': %s", startID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	return &StartEvent{
		evt,
		evt.Content.Parsed.(*StartElectionContent),
	}
}

func (store *EventStore) GetEvalsEvent(roomID id.RoomID, evalsID id.EventID) *EvalsEvent {
	evt, err := store.getAndHandleEvent(roomID, evalsID, EvalsMessage)
	if err != nil {
		log.Warnf("an error occurred getting evals event '%s': %s", evalsID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	return &EvalsEvent{
		evt,
		evt.Content.Parsed.(*EvalsMessageContent),
	}
}

func (store *EventStore) GetSumEvent(roomID id.RoomID, sumID id.EventID) *SumEvent {
	evt, err := store.getAndHandleEvent(roomID, sumID, EvalsMessage)
	if err != nil {
		log.Warnf("an error occurred getting sum event '%s': %s", sumID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	return &SumEvent{
		evt,
		evt.Content.Parsed.(*SumMessageContent),
	}
}

func (store *EventStore) getAndHandleEvent(roomID id.RoomID, eventID id.EventID, eventType event.Type) (*event.Event, error) {
	// see if we've handled this event before
	store.RLock()
	evt, exists := store.Events[eventID]
	store.RUnlock()
	if exists {
		// This event was seen before.  If evt != nil, it was handled
		// successfully and vice versa.
		return evt, nil
	}
	// we've never seen this event before
	evt, err := store.fetchEvent(roomID, eventID)
	if err != nil {
		return nil, fmt.Errorf("couldn't fetch %s event '%s': %s", eventType.Type, eventID, err)
	}
	if evt == nil {
		return nil, fmt.Errorf("fetched %s event '%s' is nil", eventType.Type, eventID)
	}
	// TODO check version here rather than in msg?
	err = evt.Content.ParseRaw(eventType)
	if err != nil {
		return nil, fmt.Errorf("couldn't parse %s event '%s' content: %s", eventType.Type, eventID, err)
	}
	if !store.EventHandlers[eventType](evt) {
		return nil, fmt.Errorf("couldn't handle %s event '%s'", eventType.Type, eventID)
	}
	return evt, nil
}

func (store *EventStore) fetchEvent(roomID id.RoomID, eventID id.EventID) (*event.Event, error) {
	evt, err := store.Client.GetEvent(roomID, eventID)
	if err != nil {
		return nil, err
	}
	if evt.Unsigned.RedactedBecause != nil {
		return nil, fmt.Errorf("event %s was redacted", eventID)
	}
	return evt, nil
}

M election/map.go => election/map.go +53 -65
@@ 10,48 10,43 @@ import (
	"time"

	"github.com/kyoh86/xdg"
	log "github.com/sirupsen/logrus"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/id"
)

type ElectionsMap struct {
	sync.RWMutex

	// The version of the elections map.  If the version in the file doesn't
	// match the current version then we can give a clear error message
	// telling the user to update their file to the current version.  The
	// version in the file should only be different if there has been a
	// breaking change to the elections map format.
	Version int `json:"version"`

	Version    int                         `json:"version"`
	// Maps election create event IDs to the corresponding election.
	M map[id.EventID]*Election `json:"elections"`

	Elections  map[id.EventID]*Election    `json:"elections"`
	// See EventStore doc for explanation
	EventStore *EventStore                 `json:"event_store,omitempty"`
	// State store
	NextBatch  string                      `json:"next_batch"`
	Rooms      map[id.RoomID]*mautrix.Room `json:"rooms"`
	UserID     id.UserID                   `json:"user_id"`
	// Maps room to a list of the room's elections, reverse sorted by
	// CreationTimestamp (i.e. newest to oldest).  Here for convenience.
	L map[id.RoomID][]*Election `json:"-"`
	// The latest time L was updated
	Ltime time.Time `json:"-"`

	// maps join event ID to voter. This is a convenience in scenarios when
	// we don't know the exact election
	Joins map[id.EventID]*Voter `json:"-"`

	// needed by the state store
	NextBatch string                      `json:"next_batch"`
	Rooms     map[id.RoomID]*mautrix.Room `json:"rooms"`
	UserID    id.UserID                   `json:"userid"`
	// CreationTimestamp (i.e. newest to oldest).  Here for convenience for
	// GUIs.  Ltime represents the latest time L was updated.
	L          map[id.RoomID][]*Election   `json:"-"`
	Ltime      time.Time                   `json:"-"`
}

const electionsMapVersion = 3
const electionsMapVersion = 4

var electionsFname = xdg.DataHome() + "/tallyard/elections.json"

func GetElectionsMap(userID id.UserID) (em *ElectionsMap, err error) {
	em = &ElectionsMap{}
	if _, err = os.Stat(electionsFname); os.IsNotExist(err) {
func GetElectionsMap(userID id.UserID) (*ElectionsMap, error) {
	if _, err := os.Stat(electionsFname); os.IsNotExist(err) {
		return newElectionsMap(userID), nil
	}
	em := &ElectionsMap{}
	jsonBytes, err := ioutil.ReadFile(electionsFname)
	if err != nil {
		return nil, fmt.Errorf("error reading data file %s: %s", electionsFname, err)


@@ 68,19 63,15 @@ func GetElectionsMap(userID id.UserID) (em *ElectionsMap, err error) {
		// TODO support multiple users?
		return nil, fmt.Errorf("user IDs don't match")
	}
	// set runtime attributes for elections
	em.L = make(map[id.RoomID][]*Election, 0)
	em.Ltime = time.Now()
	for createEventId, el := range em.M {
		em.insort(createEventId, el)
	storedElections := em.Elections
	em.Elections = make(map[id.EventID]*Election, len(storedElections))
	for _, el := range storedElections {
		em.AddElection(el)
	}
	em.Joins = make(map[id.EventID]*Voter)
	for _, el := range em.M {
		el.Save = em.Save
		for joinEventId, voter := range el.Joins {
			em.Joins[joinEventId] = voter
		}
	}
	return
	return em, nil
}

func (em *ElectionsMap) UnmarshalJSON(b []byte) error {


@@ 92,7 83,9 @@ func (em *ElectionsMap) UnmarshalJSON(b []byte) error {
	for _, room := range em.Rooms {
		for eventType, events := range room.State {
			for _, evt := range events {
				evt.Content.ParseRaw(eventType)
				if err = evt.Content.ParseRaw(eventType); err != nil {
					return err
				}
			}
		}
	}


@@ 101,68 94,57 @@ func (em *ElectionsMap) UnmarshalJSON(b []byte) error {

func newElectionsMap(userID id.UserID) *ElectionsMap {
	return &ElectionsMap{
		Version: electionsMapVersion,
		M:       make(map[id.EventID]*Election),
		L:       make(map[id.RoomID][]*Election),
		Ltime:   time.Now(),
		Joins:   make(map[id.EventID]*Voter),
		Rooms:   make(map[id.RoomID]*mautrix.Room),
		UserID:  userID,
		Version:   electionsMapVersion,
		Elections: make(map[id.EventID]*Election),
		L:         make(map[id.RoomID][]*Election),
		Ltime:     time.Now(),
		Rooms:     make(map[id.RoomID]*mautrix.Room),
		UserID:    userID,
	}
}

func (em *ElectionsMap) Save() error {
func (em *ElectionsMap) Save() {
	em.RLock()
	defer em.RUnlock()
	for _, el := range em.M {
	for _, el := range em.Elections {
		el.RLock()
		defer el.RUnlock()
	}
	jsonBytes, err := json.Marshal(em)
	if err != nil {
		return fmt.Errorf("couldn't marshal elections: %s", err)
		log.Errorf("couldn't marshal elections: %s", err)
	}
	err = ioutil.WriteFile(electionsFname, jsonBytes, 0600)
	if err != nil {
		return fmt.Errorf("couldn't save elections: %s", err)
		log.Errorf("couldn't save elections: %s", err)
	}
	return nil
}

func (em *ElectionsMap) Get(createEventID id.EventID) *Election {
	em.RLock()
	defer em.RUnlock()
	return em.M[createEventID]
	log.Info("saved elections map")
}

func (em *ElectionsMap) GetOk(createEventID id.EventID) (*Election, bool) {
func (em *ElectionsMap) GetElection(createID id.EventID) *Election {
	em.RLock()
	defer em.RUnlock()
	el, ok := em.M[createEventID]
	return el, ok
	return em.Elections[createID]
}

func (em *ElectionsMap) SetIfNotExists(createEventID id.EventID, el *Election) {
func (em *ElectionsMap) AddElection(el *Election) {
	em.Lock()
	defer em.Save()
	defer em.Unlock()
	_, exists := em.M[createEventID]
	_, exists := em.Elections[el.CreateEvt.ID]
	if exists {
		log.Warnf("election %s was already added to elections map", el.CreateEvt.ID)
		return
	}
	em.set(createEventID, el)
}

func (em *ElectionsMap) set(createEventID id.EventID, el *Election) {
	el.Lock()
	defer el.Unlock()
	em.Elections[el.CreateEvt.ID] = el
	el.Save = em.Save
	em.M[createEventID] = el
	em.insort(createEventID, el)
	em.insortElection(el.CreateEvt.ID, el)
}

func (em *ElectionsMap) insort(createEventID id.EventID, el *Election) {
	list := em.L[el.CreateEvt.RoomID]
func (em *ElectionsMap) insortElection(createID id.EventID, el *Election) {
	list := em.L[el.RoomID]
	i := sort.Search(len(list), func(i int) bool {
		return list[i].CreateEvt.Timestamp < el.CreateEvt.Timestamp
	})


@@ 170,6 152,12 @@ func (em *ElectionsMap) insort(createEventID id.EventID, el *Election) {
	copy(newList[:i], list[:i])
	newList[i] = el
	copy(newList[i+1:], list[i:])
	em.L[el.CreateEvt.RoomID] = newList
	em.L[el.RoomID] = newList
	em.Ltime = time.Now()
}

func (em *ElectionsMap) SetEventStore(eventStore *EventStore) {
	em.Lock()
	defer em.Unlock()
	em.EventStore = eventStore
}

M election/msg.go => election/msg.go +424 -172
@@ 2,11 2,13 @@ package election

import (
	"encoding/base64"
	"fmt"
	"math/big"
	"reflect"

	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/nacl/box"
	"golang.org/x/mod/semver"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"


@@ 29,9 31,9 @@ var (
		Type:  "xyz.tallyard.start",
		Class: event.MessageEventType,
	}
	// indicate's user's evaluation of their polynomial using others' inputs
	EvalMessage = event.Type{
		Type:  "xyz.tallyard.eval",
	// indicate's user's evaluations of their polynomial using others' inputs
	EvalsMessage = event.Type{
		Type:  "xyz.tallyard.evals",
		Class: event.MessageEventType,
	}
	// indicates a user's individual summation


@@ 47,306 49,556 @@ var (
)

type CreateElectionContent struct {
	Version string `json:"version"`

	Candidates []Candidate `json:"candidates"`
	Title      string      `json:"title"`
	Version    string      `json:"version"`
}

type JoinElectionContent struct {
	CreateEventId id.EventID `json:"create_event_id"`
	Input         string     `json:"input"`
	NaclPublicKey string     `json:"nacl_public_key"`
	Version string `json:"version"`

	CreateID id.EventID `json:"create_id"`
	Input    string     `json:"input"`
	PubKey   string     `json:"pub_key"`
}

type StartElectionContent struct {
	CreateEventId id.EventID   `json:"create_event_id"`
	VoterJoinIds  []id.EventID `json:"voter_join_ids"`
	Version string `json:"version"`

	CreateID id.EventID   `json:"create_id"`
	JoinIDs  []id.EventID `json:"join_ids"`
}

type EvalMessageContent struct {
	JoinEventId id.EventID            `json:"join_event_id"`
	Outputs     map[id.EventID]string `json:"outputs"`
type EvalsMessageContent struct {
	Version string `json:"version"`

	Evals   map[id.EventID]string `json:"evals"`
	JoinID  id.EventID            `json:"join_id"`
	StartID id.EventID            `json:"start_id"`
}

type SumMessageContent struct {
	JoinEventId id.EventID `json:"join_event_id"`
	Sum         string     `json:"sum"`
	Version string `json:"version"`

	EvalsIDs []id.EventID `json:"evals_ids"`
	JoinID   id.EventID   `json:"join_id"`
	Sum      string       `json:"sum"`
}

type ResultMessageContent struct {
	JoinEventId id.EventID `json:"join_event_id"`
	Result      string     `json:"result"`
	Version string `json:"version"`

	JoinID id.EventID   `json:"join_id"`
	Result string       `json:"result"`
	SumIDs []id.EventID `json:"sums"`
}

func init() {
	event.TypeMap[CreateElectionMessage] = reflect.TypeOf(CreateElectionContent{})
	event.TypeMap[JoinElectionMessage]   = reflect.TypeOf(JoinElectionContent{})
	event.TypeMap[StartElectionMessage]  = reflect.TypeOf(StartElectionContent{})
	event.TypeMap[EvalMessage]           = reflect.TypeOf(EvalMessageContent{})
	event.TypeMap[EvalsMessage]          = reflect.TypeOf(EvalsMessageContent{})
	event.TypeMap[SumMessage]            = reflect.TypeOf(SumMessageContent{})
	event.TypeMap[ResultMessage]         = reflect.TypeOf(ResultMessageContent{})
}

func SetupEventHooks(client *mautrix.Client, syncer mautrix.ExtensibleSyncer, elections *ElectionsMap) {
	wrapper := func(f func(*event.Event)) func(mautrix.EventSource, *event.Event) {
		return func(source mautrix.EventSource, evt *event.Event) {
			log.Debugf("%[5]d: <%[1]s> %[4]s (%[2]s/%[3]s)\n",
				evt.Sender, evt.Type.String(), evt.ID,
				evt.Content.AsMessage().Body, source)
func (elections *ElectionsMap) SetupEventHooks(client *mautrix.Client, syncer mautrix.ExtensibleSyncer) {
	elections.Lock()
	defer elections.Unlock()

	eventHandlers := make(map[event.Type]func(*event.Event)bool)
	eventStore := elections.EventStore

	if eventStore == nil {
		eventStore = NewEventStore(client, eventHandlers)
	} else {
		eventStore.Client = client
		eventStore.EventHandlers = eventHandlers
	}

	wrapper := func(f func(*event.Event) bool) func(*event.Event) bool {
		return func(evt *event.Event) (success bool) {
			if evt.Unsigned.RedactedBecause != nil {
				log.Debugf("%s redacted\n", evt.ID.String())
				log.Debugf("event %s was redacted", evt.ID)
				return
			}
			f(evt)
			eventStore.RLock()
			handledEvent, exists := eventStore.Events[evt.ID]
			eventStore.RUnlock()
			if exists {
				log.Debugf("event %s was already handled", evt.ID)
				return handledEvent != nil
			}
			success = f(evt)
			eventStore.Lock()
			// see EventStore doc for success explanation
			if success {
				eventStore.Events[evt.ID] = evt
			} else {
				eventStore.Events[evt.ID] = nil
			}
			eventStore.Unlock()
			return
		}
	}
	syncer.OnEventType(CreateElectionMessage, wrapper(func(evt *event.Event) {
		OnCreateElectionMessage(evt, elections)
	}))
	syncer.OnEventType(JoinElectionMessage, wrapper(func(evt *event.Event) {
		OnJoinElectionMessage(client, evt, elections)
	}))
	syncer.OnEventType(StartElectionMessage, wrapper(func(evt *event.Event) {
		OnStartElectionMessage(client, evt, elections)
	}))
	syncer.OnEventType(EvalMessage, wrapper(func(evt *event.Event) {
		OnEvalMessage(client, evt, elections)
	}))
	syncer.OnEventType(SumMessage, wrapper(func(evt *event.Event) {
		OnSumMessage(client, evt, elections)
	}))
	syncer.OnEventType(ResultMessage, wrapper(func(evt *event.Event) {
		OnResultMessage(client, evt, elections)
	}))

	eventHandlers[CreateElectionMessage] = wrapper(elections.onCreateElectionMessage)
	eventHandlers[JoinElectionMessage]   = wrapper(elections.onJoinElectionMessage)
	eventHandlers[StartElectionMessage]  = wrapper(elections.onStartElectionMessage)
	eventHandlers[EvalsMessage]          = wrapper(elections.onEvalsMessage)
	eventHandlers[SumMessage]            = wrapper(elections.onSumMessage)
	eventHandlers[ResultMessage]         = wrapper(elections.onResultMessage)

	for eventType, handler := range eventHandlers {
		func(handler func(*event.Event) bool) {
			syncer.OnEventType(eventType, func(_ mautrix.EventSource, evt *event.Event) {
				handler(evt)
			})
		}(handler)
	}

	elections.EventStore = eventStore
}

func OnCreateElectionMessage(evt *event.Event, elections *ElectionsMap) {
	// TODO: check version
type logfFunc func(format string, args ...interface{})
// error: short-circuit bug that happened locally;
// warn: short-circuit bug that happened remotely;
// debug: short-circuit info
func logFuncs(baseFormat string) (error, warn, debug logfFunc) {
	return func(reason string, args ...interface{}) {
		log.Errorf(baseFormat, fmt.Sprintf(reason, args...))
	}, func(reason string, args ...interface{}) {
		log.Warnf(baseFormat, fmt.Sprintf(reason, args...))
	}, func(reason string, args ...interface{}) {
		log.Debugf(baseFormat, fmt.Sprintf(reason, args...))
	}
}

func incompatibleVersion(version string) bool {
	return semver.Compare(
		semver.MajorMinor(version),
		semver.MajorMinor(Version),
	) != 0
}

func (elections *ElectionsMap) onCreateElectionMessage(evt *event.Event) (success bool) {
	_, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's create msg (%s) since %s", evt.Sender, evt.ID, "%s"))

	content, ok := evt.Content.Parsed.(*CreateElectionContent)
	if !ok {
		log.Warnf("ignoring %s's create since we couldn't cast message content to CreateElectionContent", evt.Sender)
		warnf("we couldn't cast message content to CreateElectionContent", evt.Sender)
		return
	}

	if incompatibleVersion(content.Version) {
		debugf("the version is incompatible")
		return
	}
	elections.SetIfNotExists(evt.ID, NewElection(content.Candidates, *evt, content.Title))

	elections.AddElection(NewElection(
		content.Candidates,
		*evt,
		evt.RoomID,
		content.Title,
	))

	return true
}

func OnJoinElectionMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) {
	content, ok := evt.Content.Parsed.(*JoinElectionContent)
	if !ok {
		log.Warnf("ignoring %s's join msg since we couldn't cast message content to JoinElectionContent", evt.Sender)
func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success bool) {
	errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's join msg (%s) since %s", evt.Sender, evt.ID, "%s"))

	content, exists := evt.Content.Parsed.(*JoinElectionContent)
	if !exists {
		warnf("we couldn't cast message content to JoinElectionContent")
		return
	}
	el := getElection(client, evt.RoomID, content.CreateEventId, elections)
	if el == nil {
		log.Warnf("ignoring %s's join msg since the election doesn't exist", evt.Sender)

	if incompatibleVersion(content.Version) {
		debugf("the version is incompatible")
		return
	}
	el.Lock()
	defer el.Unlock()
	_, voterExists := el.Joins[evt.ID]
	if voterExists {
		log.Debugf("ignoring %s's join msg since we already have their info", evt.Sender)

	createEvt := elections.EventStore.GetCreateEvent(evt.RoomID, content.CreateID)
	if createEvt == nil {
		debugf("we couldn't get the create event, %s", content.CreateID)
		return
	}

	bytes, err := base64.StdEncoding.DecodeString(content.Input)
	if err != nil {
		log.Warnf("ignoring %s's join msg since we couldn't decode their input", evt.Sender)
		warnf("we couldn't decode their input: %s", err)
		return
	}
	input := new(big.Int).SetBytes(bytes)
	bytes, err = base64.StdEncoding.DecodeString(content.NaclPublicKey)

	bytes, err = base64.StdEncoding.DecodeString(content.PubKey)
	if err != nil {
		log.Warnf("ignoring %s's join msg since we couldn't decode their public key: %s", evt.Sender, err)
		warnf("we couldn't decode their public key: %s", err)
		return
	}
	var pubKey [32]byte
	copy(pubKey[:], bytes)
	voter := NewVoter(evt.Sender, input, &pubKey, evt)
	el.Joins[evt.ID] = voter
	elections.Joins[evt.ID] = voter

	el := elections.GetElection(createEvt.ID)
	if el == nil {
		// should never happen because we retrieved the create event
		// above
		errorf("election %s doesn't exist", createEvt.ID)
		return
	}

	el.Lock()
	defer el.Save()
	defer el.Unlock()

	el.Joins[evt.ID] = NewVoter(input, evt, &pubKey)

	return true
}

func OnStartElectionMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) {
func (elections *ElectionsMap) onStartElectionMessage(evt *event.Event) (success bool) {
	errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's start msg (%s) since %s", evt.Sender, evt.ID, "%s"))

	content, ok := evt.Content.Parsed.(*StartElectionContent)
	if !ok {
		log.Warnf("ignoring %s's start msg since we couldn't cast message content to StartElectionContent", evt.Sender)
		warnf("we couldn't cast message content to StartElectionContent")
		return
	}

	if incompatibleVersion(content.Version) {
		debugf("the version is incompatible")
		return
	}

	createEvt := elections.EventStore.GetCreateEvent(evt.RoomID, content.CreateID)
	if createEvt == nil {
		debugf("we couldn't get the create event, %s", content.CreateID)
		return
	}
	el := getElection(client, evt.RoomID, content.CreateEventId, elections)

	for _, joinID := range content.JoinIDs {
		joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, joinID)
		if joinEvt == nil {
			debugf("we couldn't get the join event %s", joinID)
			return
		}
		if joinEvt.CreateID != content.CreateID {
			warnf("the join event %s didn't have the same create ID, %s, as the start event did",
				joinID, content.CreateID)
			return
		}
	}

	if createEvt.Sender != evt.Sender {
		warnf("they didn't create the election")
		return
	}

	el := elections.GetElection(createEvt.ID)
	if el == nil {
		log.Warnf("ignoring %s's start msg since the election doesn't exist", evt.Sender)
		// should never happen because we retrieved the craete event
		// above
		errorf("election %s doesn't exist", createEvt.ID)
		return
	}

	// election should exist since we were able to getCreateEvent
	el.Lock()
	defer el.Save()
	defer el.Unlock()
	if evt.Sender != el.CreateEvt.Sender {
		log.Warnf("ignoring %s's start msg since they didn't start the election", evt.Sender)
		return
	}
	// TODO we should probably just bail when there are multiple start messages
	if el.StartEvt != nil && el.StartEvt.Timestamp < evt.Timestamp {
		log.Warnf("ignoring %s's start msg since the election's already been started", evt.Sender)

	if el.StartID != nil {
		warnf("the election's already been started")
		return
	}
	el.StartEvt = evt
	el.FinalVoters = &content.VoterJoinIds
	// TODO check election voters

	el.StartID = &evt.ID
	el.FinalJoinIDs = &content.JoinIDs

	return true
}

func OnEvalMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) {
	content, ok := evt.Content.Parsed.(*EvalMessageContent)
func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) {
	errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's evals msg (%s) since %s", evt.Sender, evt.ID, "%s"))

	content, ok := evt.Content.Parsed.(*EvalsMessageContent)
	if !ok {
		log.Warn("ignoring %s's eval msg since we couldn't cast message content to EvalMessageContent", evt.Sender)
		warnf("we couldn't cast message content to EvalMessageContent")
		return
	}
	voter := getVoter(client, evt.RoomID, content.JoinEventId, elections)
	if voter == nil {
		log.Warnf("ignoring %s's eval msg since voter doesn't exist", evt.Sender)

	if incompatibleVersion(content.Version) {
		debugf("the version is incompatible")
		return
	}

	startEvt := elections.EventStore.GetStartEvent(evt.RoomID, content.StartID)
	if startEvt == nil {
		debugf("we couldn't get the start event, %s", content.StartID)
		return
	}

	joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, content.JoinID)
	if joinEvt == nil {
		debugf("we couldn't get the join event, %s", content.JoinID)
		return
	}

	// ensure keys of Evals are in JoinIDs of start event
	if len(content.Evals) != len(startEvt.JoinIDs) {
		warnf("the number of evals is wrong (%s instead of %s)",
			len(content.Evals), len(startEvt.JoinIDs))
		return
	}
	// if Content were faulty, voter would have been nil
	createEventId := voter.JoinEvt.Content.Parsed.(*JoinElectionContent).CreateEventId
	el := getElection(client, evt.RoomID, createEventId, elections)
	for _, joinID := range startEvt.JoinIDs {
		if _, exists := content.Evals[joinID]; !exists {
			warnf("it doesn't include an eval for join %s", joinID)
			return
		}
	}
	// is the voter's join ID in the start event's list of join IDs?
	if _, exists := content.Evals[joinEvt.ID]; !exists {
		warnf("the join event %s is not listed in the start event's join IDs", joinEvt.ID)
		return
	}

	el := elections.GetElection(startEvt.CreateID)
	if el == nil {
		log.Warnf("ignoring %s's eval msg since the election doesn't exist", evt.Sender)
		// should never happen because we retrieved the start/join
		// events above
		errorf("election %s doesn't exist", startEvt.CreateID)
		return
	}

	el.Lock()
	defer el.Save()
	defer el.Unlock()
	if el.LocalVoter == nil {

	voter := el.Joins[joinEvt.ID]
	if voter == nil {
		// should never happen because we called getJoinEvent above
		errorf("voter %s doesn't exist", joinEvt.ID)
		return
	}
	encodedEncryptedOutput, ok := content.Outputs[el.LocalVoter.JoinEvt.ID]
	if !ok {
		log.Errorf("our user ID was not included in an eval message! The election will be unable to finish; blame %s", evt.Sender)

	if voter.EvalsID != nil {
		warnf("voter submitted multiple evals events")
		return
	}

	voter.EvalsID = &evt.ID

	if el.LocalVoter == nil {
		return true
	}

	encodedEncryptedOutput, exists := content.Evals[el.LocalVoter.JoinEvt.ID]
	// If our ID doesn't exist in the keys of evalsEvtContent.Evals, our
	// JoinID wasn't included in the startEvtContent.JoinIDs (since we
	// checked that the two are equivalent above).  I'm checking membership
	// in evalsEvtContent.Evals instead of startEvtContent.JoinIDs because
	// maps are easier than slices for that.
	if !exists {
		debugf("we didn't join the election in time (or the election creator excluded us)")
		return
	}
	encryptedOutput, err := base64.StdEncoding.DecodeString(encodedEncryptedOutput)
	if err != nil {
		log.Errorf("couldn't decode %s's encrypted output: %s", evt.Sender, err)
		warnf("couldn't decode %s's encrypted output: %s", evt.Sender, err)
		return
	}

	var decryptNonce [24]byte
	copy(decryptNonce[:], encryptedOutput[:24])
	decryptedOutput, ok := box.Open(nil, encryptedOutput[24:], &decryptNonce, &voter.PubKey, &el.LocalVoter.PrivKey)
	if !ok {
		log.Errorf("decryption error")
		warnf("couldn't decrypt eval for us")
		return
	}

	voter.Eval = new(big.Int).SetBytes(decryptedOutput)

	return true
}

func OnSumMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) {
func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) {
	errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's sum msg (%s) since %s", evt.Sender, evt.ID, "%s"))

	content, ok := evt.Content.Parsed.(*SumMessageContent)
	if !ok {
		log.Warnf("ignoring %s's sum since we couldn't cast message content to SumMessageContent", evt.Sender)
		warnf("we couldn't cast message content to SumMessageContent")
		return
	}
	voter := getVoter(client, evt.RoomID, content.JoinEventId, elections)
	if voter == nil {
		log.Warnf("ignoring %s's sum since voter doesn't exist", evt.Sender)

	if incompatibleVersion(content.Version) {
		debugf("the version is incompatible")
		return
	}
	// if Content were faulty, voter would have been nil
	createEventId := voter.JoinEvt.Content.Parsed.(*JoinElectionContent).CreateEventId
	el := getElection(client, evt.RoomID, createEventId, elections)

	joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, content.JoinID)
	if joinEvt == nil {
		debugf("we couldn't get the join event, %s", content.JoinID)
		return
	}

	if len(content.EvalsIDs) == 0 {
		warnf("evals length is zero")
		return
	}

	joinIDs := make(map[id.EventID]id.EventID)
	var startID id.EventID

	for _, evalsID := range content.EvalsIDs {
		evalsEvt := elections.EventStore.GetEvalsEvent(evt.RoomID, evalsID)
		if evalsEvt == nil {
			debugf("we couldn't get an evals event, %s", evalsID)
			return
		}

		// ensure all evals have the same start ID
		if startID == "" {
			startID = evalsEvt.StartID
		} else if evalsEvt.StartID != startID {
			warnf("at least two evals have different startIDs (%s, %s, ...)", startID, evalsEvt.StartID)
			return
		}

		// ensure no two evals use the same join ID.  This'll also catch
		// duplicates in EvalsIDs
		if same, exists := joinIDs[evalsEvt.JoinID]; exists {
			warnf("at least two evals use the same join ID (%s, %s, ...)", same, evalsID)
			return
		}

		joinIDs[evalsEvt.JoinID] = evalsID

		// Remember, we don't need to check that the evals's join ID is
		// in the start event because we successfully retrieved the
		// evals event above.  Recursion is wonderful.
	}

	joinEvtContent := joinEvt.Content.Parsed.(*JoinElectionContent)
	el := elections.GetElection(joinEvtContent.CreateID)
	if el == nil {
		log.Warnf("ignoring %s's sum since the election does not exist", evt.Sender)
		// should never happen because we retrieved the start/join
		// events above
		errorf("election %s doesn't exist", joinEvtContent.CreateID)
		return
	}

	el.Lock()
	defer el.Save()
	defer el.Unlock()

	voter := el.Joins[joinEvt.ID]
	if voter == nil {
		// should never happen because we called getJoinEvent above
		errorf("voter %s doesn't exist", joinEvt.ID)
		return
	}

	if voter.SumID != nil {
		warnf("voter submitted multiple sum events")
		return
	}

	voter.SumID = &evt.ID

	bytes, err := base64.StdEncoding.DecodeString(content.Sum)
	if err != nil {
		log.Warnf("ignoring %s's sum since we couldn't decode their sum: %s", evt.Sender, err)
		warnf("we couldn't decode their sum: %s",  err)
		return
	}

	sum := new(big.Int).SetBytes(bytes)
	voter.Sum = sum

	return true
}

func OnResultMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) {
func (elections *ElectionsMap) onResultMessage(evt *event.Event) (success bool) {
	errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's result msg (%s) since %s", evt.Sender, evt.ID, "%s"))

	content, ok := evt.Content.Parsed.(*ResultMessageContent)
	if !ok {
		log.Warnf("ignoring %s's result since we couldn't cast message content to ResultMessageContent", evt.Sender)
		warnf("we couldn't cast message content to ResultMessageContent")
		return
	}
	voter := getVoter(client, evt.RoomID, content.JoinEventId, elections)
	if voter == nil {
		log.Warnf("ignoring %s's sum since voter doesn't exist", evt.Sender)
		return
	}
	createEventId := voter.JoinEvt.Content.Parsed.(*JoinElectionContent).CreateEventId
	el := getElection(client, evt.RoomID, createEventId, elections)
	if el == nil {
		log.Warnf("ignoring %s's result since the election does not exist", evt.Sender)

	if incompatibleVersion(content.Version) {
		debugf("the version is incompatible")
		return
	}
	el.Lock()
	defer el.Unlock()
	bytes, err := base64.StdEncoding.DecodeString(content.Result)
	if err != nil {
		log.Warnf("ignoring %s's result since we couldn't decode the result: %s", evt.Sender, err)

	if len(content.SumIDs) == 0 {
		warnf("sums length is zero")
		return
	}
	result := new(big.Int).SetBytes(bytes)
	voter.Result = result
}

func getElection(client *mautrix.Client, roomID id.RoomID, createEventId id.EventID, elections *ElectionsMap) *Election {
	el, exists := elections.GetOk(createEventId)
	if exists {
		return el
	}
	createEvent, err := client.GetEvent(roomID, createEventId)
	if err != nil {
		log.Warnf("couldn't retrieve election create event: %s", err)
		return nil
	}
	if createEvent.Unsigned.RedactedBecause != nil {
		log.Debug("election redacted")
		return nil
	}
	err = createEvent.Content.ParseRaw(CreateElectionMessage)
	if err != nil {
		log.Errorf("couldn't parse create event: %s", err)
		return nil
	var evalsIDs map[id.EventID]struct{}

	for _, sumID := range content.SumIDs {
		sumEvt := elections.EventStore.GetSumEvent(evt.RoomID, sumID)
		if sumEvt == nil {
			warnf("we couldn't get the sum event %s", sumID)
			return
		}
		if evalsIDs == nil {
			// run on first iteration only
			evalsIDs = make(map[id.EventID]struct{})
			for _, evalID := range sumEvt.EvalsIDs {
				evalsIDs[evalID] = struct{}{}
			}
			continue
		}
		if len(sumEvt.EvalsIDs) != len(evalsIDs) {
			warnf("sum events %s and %s don't have the same number of evals events",
				sumID, content.SumIDs[0])
		}
		for _, evalsID := range sumEvt.EvalsIDs {
			if _, exists := evalsIDs[evalsID]; !exists {
				warnf("evals ID %s exists in one evals event but not another", evalsID)
				return
			}
		}
	}
	OnCreateElectionMessage(createEvent, elections)
	el, exists = elections.GetOk(createEventId)
	if !exists {
		log.Warn("couldn't create election")
		return nil

	joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, content.JoinID)
	if joinEvt == nil {
		debugf("we couldn't get the join, %s", content.JoinID)
		return
	}
	return el
}

func getVoter(client *mautrix.Client, roomID id.RoomID, joinEventId id.EventID, elections *ElectionsMap) *Voter {
	voter, exists := elections.Joins[joinEventId]
	if exists {
		return voter
	el := elections.GetElection(joinEvt.CreateID)
	if el == nil {
		// should never happen because we retrieved the join event
		errorf("election %s doesn' exist", joinEvt.CreateID)
		return
	}
	joinEvent, err := client.GetEvent(roomID, joinEventId)
	if err != nil {
		log.Warnf("couldn't retrieve join event: %s", err)
		return nil

	voter := el.Joins[joinEvt.ID]
	if voter == nil {
		errorf("voter %s doesn't exist", joinEvt.ID)
		return
	}
	if joinEvent.Unsigned.RedactedBecause != nil {
		log.Debug("join redacted")
		return nil

	if voter.ResultID != nil {
		warnf("voter %s submitted multiple results", joinEvt.ID)
		return
	}
	err = joinEvent.Content.ParseRaw(JoinElectionMessage)

	bytes, err := base64.StdEncoding.DecodeString(content.Result)
	if err != nil {
		log.Errorf("couldn't parse join event: %s", err)
		return nil
	}
	OnJoinElectionMessage(client, joinEvent, elections)
	voter, exists = elections.Joins[joinEventId]
	if !exists {
		log.Warn("couldn't find voter")
		return nil
		warnf("we couldn't decode the result: %s", err)
		return
	}
	return voter

	el.Lock()
	defer el.Unlock()

	result := new(big.Int).SetBytes(bytes)
	voter.Result = result

	return true
}

M election/version.go => election/version.go +1 -1
@@ 1,3 1,3 @@
package election

const Version string = "0.3.0"
const Version string = "v0.3.0"

M election/voter.go => election/voter.go +90 -95
@@ 20,13 20,16 @@ import (
)

type Voter struct {
	Eval    *big.Int    `json:"eval,omitempty"`
	Input   big.Int     `json:"input"`
	JoinEvt event.Event `json:"join_evt"`
	PubKey  [32]byte    `json:"pub_key"`
	Result  *big.Int    `json:"result,omitempty"`
	Sum     *big.Int    `json:"sum,omitempty"`
	UserID  id.UserID   `json:"user_id"`
	Input    big.Int     `json:"input"`
	JoinEvt  event.Event `json:"join_evt"`
	PubKey   [32]byte    `json:"pub_key"`

	Eval     *big.Int    `json:"eval,omitempty"`
	EvalsID  *id.EventID `json:"evals_id,omitempty"`
	Result   *big.Int    `json:"result,omitempty"`
	ResultID *id.EventID `json:"result_id,omitempty"`
	Sum      *big.Int    `json:"sum,omitempty"`
	SumID    *id.EventID `json:"sum_id,omitempty"`
}

type LocalVoter struct {


@@ 36,132 39,119 @@ type LocalVoter struct {
	PrivKey [32]byte   `json:"priv_key"`
}

func NewVoter(userID id.UserID, input *big.Int, pubKey *[32]byte, joinEvt *event.Event) *Voter {
func NewVoter(input *big.Int, joinEvt *event.Event, pubKey *[32]byte) *Voter {
	return &Voter{
		Input:   *new(big.Int).Set(input),
		JoinEvt: *joinEvt,
		PubKey:  *pubKey,
		UserID:  userID,
	}
}

func CreateElection(client *mautrix.Client, candidates []Candidate, title string, roomID id.RoomID, elections *ElectionsMap) (*Election, error) {
func (elections *ElectionsMap) CreateElection(client *mautrix.Client, candidates []Candidate, title string, roomID id.RoomID) (*Election, error) {
	resp, err := client.SendMessageEvent(roomID, CreateElectionMessage, CreateElectionContent{
		Version:    Version,
		Candidates: candidates,
		Title:      title,
		Version:    Version,
	})
	if err != nil {
		return nil, err
	}

	createEvt, err := client.GetEvent(roomID, resp.EventID)
	if err != nil {
		return nil, err
	}
	err = createEvt.Content.ParseRaw(CreateElectionMessage)
	if err != nil {
		return nil, err
	createEvt := elections.EventStore.GetCreateEvent(roomID, resp.EventID)
	if createEvt == nil {
		return nil, fmt.Errorf("couldn't process our own create event, %s", resp.EventID)
	}

	OnCreateElectionMessage(createEvt, elections)
	el, exists := elections.GetOk(resp.EventID)
	if !exists {
		return nil, errors.New("couldn't create election")
	}
	return el, nil
	return elections.GetElection(createEvt.ID), nil
}

func (el *Election) JoinElection(client *mautrix.Client) error {
func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore) error {
	pubKey, privKey, err := box.GenerateKey(rand.Reader)
	if err != nil {
		return err
	}

	input, err := math.RandomBigInt(1024, false)
	if err != nil {
		return err
	}

	el.Lock()
	defer el.Save()
	defer el.Unlock()

	resp, err := client.SendMessageEvent(el.CreateEvt.RoomID, JoinElectionMessage, JoinElectionContent{
		CreateEventId: el.CreateEvt.ID,
		Input:         base64.StdEncoding.EncodeToString(input.Bytes()),
		NaclPublicKey: base64.StdEncoding.EncodeToString((*pubKey)[:]),
	resp, err := client.SendMessageEvent(el.RoomID, JoinElectionMessage, JoinElectionContent{
		Version:  Version,
		CreateID: el.CreateEvt.ID,
		Input:    base64.StdEncoding.EncodeToString(input.Bytes()),
		PubKey:   base64.StdEncoding.EncodeToString((*pubKey)[:]),
	})
	if err != nil {
		return err
	}

	joinEvt, err := client.GetEvent(el.CreateEvt.RoomID, resp.EventID)
	if err != nil {
		return err
	}
	err = joinEvt.Content.ParseRaw(JoinElectionMessage)
	if err != nil {
		return err
	joinEvt := eventStore.GetJoinEvent(el.RoomID, resp.EventID)
	if joinEvt == nil {
		return fmt.Errorf("couldn't process our own join event, %s", resp.EventID)
	}

	el.Lock()
	defer el.Save()
	defer el.Unlock()

	el.LocalVoter = &LocalVoter{
		Voter:   NewVoter(client.UserID, input, pubKey, joinEvt),
		Voter:   el.Joins[joinEvt.ID],
		PrivKey: *privKey,
	}
	el.Joins[resp.EventID] = el.LocalVoter.Voter

	return nil
}

func (el *Election) StartElection(client *mautrix.Client) error {
func (el *Election) StartElection(client *mautrix.Client, eventStore *EventStore) error {
	// TODO err from this function if we didn't create the election
	el.Lock()
	defer el.Save()
	defer el.Unlock()
	userIdMap := make(map[id.UserID]*Voter)

	el.RLock()
	// one vote per userID
	for _, voter := range el.Joins {
		prevVoter, exists := userIdMap[voter.UserID]
		if exists {
			// use latest join
			if voter.JoinEvt.Timestamp > prevVoter.JoinEvt.Timestamp {
				userIdMap[voter.UserID] = voter
			}
		} else {
			userIdMap[voter.UserID] = voter
		prevVoter, exists := userIdMap[voter.JoinEvt.Sender]
		if !exists {
			userIdMap[voter.JoinEvt.Sender] = voter
			continue
		}

		// use latest join
		if voter.JoinEvt.Timestamp > prevVoter.JoinEvt.Timestamp {
			userIdMap[voter.JoinEvt.Sender] = voter
		}
	}
	el.RUnlock()

	voters := make([]id.EventID, 0, len(userIdMap))
	for _, voter := range userIdMap {
		voters = append(voters, voter.JoinEvt.ID)
	}
	resp, err := client.SendMessageEvent(el.CreateEvt.RoomID, StartElectionMessage, StartElectionContent{
		CreateEventId: el.CreateEvt.ID,
		VoterJoinIds:  voters,

	resp, err := client.SendMessageEvent(el.RoomID, StartElectionMessage, StartElectionContent{
		Version:  Version,
		CreateID: el.CreateEvt.ID,
		JoinIDs:  voters,
	})
	if err != nil {
		return err
	}
	startEvt, err := client.GetEvent(el.CreateEvt.RoomID, resp.EventID)
	if err != nil {
		return err
	}
	err = startEvt.Content.ParseRaw(StartElectionMessage)
	if err != nil {
		return err

	startEvt := eventStore.GetStartEvent(el.RoomID, resp.EventID)
	if startEvt == nil {
		return fmt.Errorf("couldn't process our own start event, %s", resp.EventID)
	}
	// OnStartElectionMessage(client, startEvt, elections)
	el.StartEvt = startEvt
	el.FinalVoters = &voters

	return nil
}

func (el *Election) WaitForVoters(client *mautrix.Client) error {
func (el *Election) WaitForJoins(client *mautrix.Client) error {
	fmt.Println("waiting for others...")
	el.RLock()
	if el.StartEvt == nil {
		return errors.New("WaitForVoters called before election started")
	if el.StartID == nil {
		return errors.New("WaitForJoins called before election started")
	}
	finalVoters := *el.FinalVoters
	finalVoters := *el.FinalJoinIDs
	el.RUnlock()
	var wg sync.WaitGroup
	for _, voterJoinId := range finalVoters {


@@ 184,24 174,23 @@ func (el *Election) SendEvals(client *mautrix.Client) error {
	el.Lock()
	defer el.Save()
	defer el.Unlock()
	content := EvalMessageContent{
		JoinEventId: el.LocalVoter.JoinEvt.ID,
		Outputs:     make(map[id.EventID]string),
	}
	for _, voterJoinId := range *el.FinalVoters {
		voter := el.Joins[voterJoinId]
		output := el.LocalVoter.Poly.Eval(&voter.Input)
	evals := make(map[id.EventID]string)
	for _, joinID := range *el.FinalJoinIDs {
		voter := el.Joins[joinID]
		eval := el.LocalVoter.Poly.Eval(&voter.Input)
		var nonce [24]byte
		if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
			return err
		}
		if voter.JoinEvt.ID == el.LocalVoter.JoinEvt.ID {
			el.LocalVoter.Eval = output
		}
		encrypted := box.Seal(nonce[:], output.Bytes(), &nonce, &voter.PubKey, &el.LocalVoter.PrivKey)
		content.Outputs[voter.JoinEvt.ID] = base64.StdEncoding.EncodeToString(encrypted)
	}
	_, err := client.SendMessageEvent(el.CreateEvt.RoomID, EvalMessage, content)
		encrypted := box.Seal(nonce[:], eval.Bytes(), &nonce, &voter.PubKey, &el.LocalVoter.PrivKey)
		evals[voter.JoinEvt.ID] = base64.StdEncoding.EncodeToString(encrypted)
	}
	_, err := client.SendMessageEvent(el.RoomID, EvalsMessage, EvalsMessageContent{
		Version: Version,
		Evals:   evals,
		JoinID:  el.LocalVoter.JoinEvt.ID,
		StartID: *el.StartID,
	})
	if err != nil {
		el.LocalVoter.Eval = nil
		return err


@@ 211,33 200,39 @@ func (el *Election) SendEvals(client *mautrix.Client) error {

func (el *Election) SendSum(client *mautrix.Client) error {
	sum := big.NewInt(0)
	var evalsIDs []id.EventID
	var wg sync.WaitGroup
	for _, voterJoinId := range *el.FinalVoters {
	for _, voterJoinId := range *el.FinalJoinIDs {
		wg.Add(1)
		go func(voter *Voter) {
			for voter.Eval == nil {
			for voter.EvalsID == nil {
				time.Sleep(time.Millisecond * 100)
			}
			evalsIDs = append(evalsIDs, *voter.EvalsID)
			sum.Add(sum, voter.Eval)
			wg.Done()
		}(el.Joins[voterJoinId])
	}
	wg.Wait()
	_, err := client.SendMessageEvent(el.CreateEvt.RoomID, SumMessage, SumMessageContent{
		JoinEventId: el.LocalVoter.JoinEvt.ID,
		Sum:         base64.StdEncoding.EncodeToString(sum.Bytes()),
	_, err := client.SendMessageEvent(el.RoomID, SumMessage, SumMessageContent{
		Version:  Version,
		EvalsIDs: evalsIDs,
		JoinID:   el.LocalVoter.JoinEvt.ID,
		Sum:      base64.StdEncoding.EncodeToString(sum.Bytes()),
	})
	if err != nil {
		return err
	}
	el.Lock()
	defer el.Save()
	defer el.Unlock()
	el.LocalVoter.Sum = sum
	el.Save()
	return nil
}

func (el *Election) GetSums(client *mautrix.Client) {
	var wg sync.WaitGroup
	for _, voterJoinId := range *el.FinalVoters {
	for _, voterJoinId := range *el.FinalJoinIDs {
		wg.Add(1)
		go func(voter *Voter) {
			for voter.Sum == nil {


@@ 262,16 257,16 @@ func (el *Election) GetSums(client *mautrix.Client) {
}

func constructPolyMatrix(el *Election) math.Matrix {
	mat := make(math.Matrix, len(el.Joins))
	mat := make(math.Matrix, len(*el.FinalJoinIDs))

	i := 0
	for _, voterJoinId := range *el.FinalVoters {
	for _, voterJoinId := range *el.FinalJoinIDs {
		voter := el.Joins[voterJoinId]
		mat[i] = make([]big.Rat, len(mat)+1) // includes column for sum
		row := mat[i]
		row[0].SetInt64(1)
		var j int64
		for j = 1; j <= int64(len(*el.FinalVoters)-1); j++ {
		for j = 1; j <= int64(len(*el.FinalJoinIDs)-1); j++ {
			row[j].SetInt(new(big.Int).Exp(&voter.Input, big.NewInt(j), nil))
		}
		row[j].SetInt(voter.Sum)

M go.mod => go.mod +1 -0
@@ 11,6 11,7 @@ require (
	github.com/rivo/tview v0.0.0-20200528200248-fe953220389f
	github.com/sirupsen/logrus v1.2.0
	golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
	golang.org/x/mod v0.4.1
	golang.org/x/net v0.0.0-20201110031124-69a78807bb2b // indirect
	gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
	maunium.net/go/mautrix v0.7.13

M go.sum => go.sum +8 -0
@@ 58,16 58,21 @@ github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U=
github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/mod v0.4.1 h1:Kvvh58BN8Y9/lBi7hTekvtMpm07eUZ0ck5pRHpsMWrY=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM=
golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=


@@ 84,6 89,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

M ui/tui.go => ui/tui.go +11 -10
@@ 105,7 105,8 @@ func RoomTUI(client *mautrix.Client, roomID id.RoomID, elections *election.Elect
			if title == "" {
				title = "<no name>"
			}
			list.AddItem(title, fmt.Sprintf("created by %s, ID: %s", el.CreateEvt.Sender, el.CreateEvt.ID), 0, nil)
			list.AddItem(title, fmt.Sprintf("created by %s, ID: %s",
				el.CreateEvt.Sender, el.CreateEvt.ID), 0, nil)
		}
	}
	go func() {


@@ 131,7 132,7 @@ func RoomTUI(client *mautrix.Client, roomID id.RoomID, elections *election.Elect
			if el.LocalVoter == nil {
				// ask user if s/he wants to join election
				if joinElectionConfirmation(el) {
					err = el.JoinElection(client)
					err = el.JoinElection(client, elections.EventStore)
				} else {
					// user needs to select a different election
					el, err = RoomTUI(client, roomID, elections)


@@ 147,11 148,11 @@ func RoomTUI(client *mautrix.Client, roomID id.RoomID, elections *election.Elect
		}
		log.Debugf("created election title: %s", title)
		log.Debugf("created election candidates: %s", candidates)
		el, err = election.CreateElection(client, candidates, title, roomID, elections)
		el, err = elections.CreateElection(client, candidates, title, roomID)
		if err != nil {
			return
		}
		err = el.JoinElection(client)
		err = el.JoinElection(client, elections.EventStore)
		if err != nil {
			return
		}


@@ 169,7 170,7 @@ func joinElectionConfirmation(el *election.Election) (shouldJoin bool) {

	el.RLock()
	// TODO: handle when election starts while in modal
	if el.StartEvt != nil {
	if el.StartID != nil {
		buttons = []string{"Ok"}
		text = "Election has already started, sorry"
	} else {


@@ 243,13 244,13 @@ func CreateElectionTUI() (title string, candidates []election.Candidate) {
	return title, candidates
}

func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error {
func ElectionWaitTUI(client *mautrix.Client, el *election.Election, eventStore *election.EventStore) error {
	votersTextView := tview.NewTextView()
	frame := tview.NewFrame(votersTextView)
	frame.SetTitle(el.Title).SetBorder(true)
	app := newTallyardApplication()
	el.RLock()
	if el.CreateEvt.Sender == el.LocalVoter.UserID {
	if el.CreateEvt.Sender == el.LocalVoter.JoinEvt.Sender {
		frame.AddText("Press enter to start the election", false, tview.AlignCenter, tcell.ColorWhite)
		app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
			if event.Key() == tcell.KeyEnter {


@@ 257,7 258,7 @@ func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error {
				// 	frame.Clear()
				// 	frame.AddText("Starting election...", false, tview.AlignCenter, tcell.ColorWhite)
				// })
				err := el.StartElection(client)
				err := el.StartElection(client, eventStore)
				if err != nil {
					panic(err)
				}


@@ 273,7 274,7 @@ func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error {
		// TODO: handle duplicate joins from one UserID
		voters := make([]string, 0, len(el.Joins))
		for _, voter := range el.Joins {
			voters = append(voters, voter.UserID.String())
			voters = append(voters, voter.JoinEvt.Sender.String())
		}
		el.RUnlock()
		sort.Strings(voters)


@@ 287,7 288,7 @@ func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error {

			// has the election started?
			el.RLock()
			started := el.StartEvt != nil
			started := el.StartID != nil
			el.RUnlock()
			if started {
				app.Stop()