From a313d735a011db11ea258d800d9a630737c0d169 Mon Sep 17 00:00:00 2001 From: David Florness Date: Sun, 14 Feb 2021 13:03:12 -0500 Subject: [PATCH] Try to mitigate against replay attacks and refactor event retrieval --- cmd/tallyard/main.go | 89 ++++--- election/election.go | 28 +- election/event.go | 208 +++++++++++++++ election/map.go | 118 ++++----- election/msg.go | 596 ++++++++++++++++++++++++++++++------------- election/version.go | 2 +- election/voter.go | 185 +++++++------- go.mod | 1 + go.sum | 8 + ui/tui.go | 21 +- 10 files changed, 863 insertions(+), 393 deletions(-) create mode 100644 election/event.go diff --git a/cmd/tallyard/main.go b/cmd/tallyard/main.go index c7453a9..1e55299 100644 --- a/cmd/tallyard/main.go +++ b/cmd/tallyard/main.go @@ -5,9 +5,9 @@ import ( "os" "github.com/kyoh86/xdg" + log "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" "tallyard.xyz/election" "tallyard.xyz/math" @@ -15,41 +15,50 @@ import ( "tallyard.xyz/ui" ) -func electionFilter(localUserID id.UserID) *mautrix.Filter { - return &mautrix.Filter{ - AccountData: mautrix.FilterPart{ +var electionFilter = &mautrix.Filter{ + AccountData: mautrix.FilterPart{ + NotTypes: []event.Type{event.NewEventType("*")}, + }, + Presence: mautrix.FilterPart{ + NotTypes: []event.Type{event.NewEventType("*")}, + }, + Room: mautrix.RoomFilter{ + Ephemeral: mautrix.FilterPart{ NotTypes: []event.Type{event.NewEventType("*")}, }, - Presence: mautrix.FilterPart{ - NotTypes: []event.Type{event.NewEventType("*")}, + State: mautrix.FilterPart{ + Types: []event.Type{event.StateRoomName}, }, - Room: mautrix.RoomFilter{ - Ephemeral: mautrix.FilterPart{ - NotTypes: []event.Type{event.NewEventType("*")}, - }, - State: mautrix.FilterPart{ - Types: []event.Type{event.StateRoomName}, - }, - Timeline: mautrix.FilterPart{ - LazyLoadMembers: true, - Limit: 50, - NotSenders: []id.UserID{localUserID}, - Types: []event.Type{ - election.CreateElectionMessage, - election.JoinElectionMessage, - election.StartElectionMessage, - election.EvalMessage, - election.SumMessage, - election.ResultMessage, - }, + Timeline: mautrix.FilterPart{ + LazyLoadMembers: true, + Types: []event.Type{ + election.CreateElectionMessage, + election.JoinElectionMessage, + election.StartElectionMessage, + election.EvalsMessage, + election.SumMessage, + election.ResultMessage, }, }, - } + }, } func main() { os.MkdirAll(xdg.DataHome() + "/tallyard", 0700) + // You could set this to any `io.Writer` such as a file + file, err := os.OpenFile(xdg.DataHome() + "/tallyard/tallyard.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err == nil { + log.SetLevel(log.DebugLevel) + log.SetOutput(file) + } else { + log.Errorf("failed to open logging file; using default stderr: %s", err) + } + log.Info("tallyard starting...") + defer func() { + log.Info("tallyard exiting...") + }() + authInfo, err := matrix.GetAuthInfo() if err != nil { fmt.Fprintln(os.Stderr, err) @@ -68,18 +77,16 @@ func main() { } client.Store = matrix.NewTallyardStore(elections) defer func() { - err := elections.Save() - if err != nil { - panic(err) - } + elections.Save() }() syncer := client.Syncer.(*mautrix.DefaultSyncer) + syncer.OnEvent(debugEventHook) syncer.OnEvent(client.Store.(*matrix.TallyardStore).UpdateState) - election.SetupEventHooks(client, syncer, elections) + elections.SetupEventHooks(client, syncer) go func() { - res, err := client.CreateFilter(electionFilter(client.UserID)) + res, err := client.CreateFilter(electionFilter) if err != nil { panic(err) } @@ -111,11 +118,11 @@ func main() { } // wait for election to start - err = ui.ElectionWaitTUI(client, el) + err = ui.ElectionWaitTUI(client, el, elections.EventStore) if err != nil { panic(err) } - if el.StartEvt == nil { + if el.StartID == nil { // election never started; user likely hit C-c return } @@ -130,15 +137,15 @@ func main() { } // set random poly with ballot - el.LocalVoter.Poly = math.NewRandomPoly(uint(len(*el.FinalVoters)-1), 1024, *el.LocalVoter.Ballot) + el.LocalVoter.Poly = math.NewRandomPoly(uint(len(*el.FinalJoinIDs)-1), 1024, *el.LocalVoter.Ballot) el.Save() // wait for other voters to finish - el.WaitForVoters(client) + el.WaitForJoins(client) - // send eval if we need to (if LocalVoter.Eval is set, we've already + // send evals if we need to (if LocalVoter.Evals is set, we've already // sent the event) - if el.LocalVoter.Eval == nil { + if el.LocalVoter.EvalsID == nil { err = el.SendEvals(client) if err != nil { panic(err) @@ -156,3 +163,9 @@ func main() { el.GetSums(client) } + +func debugEventHook(_ mautrix.EventSource, evt *event.Event) { + log.Debugf("<%[1]s> %[4]s (%[2]s/%[3]s)\n", + evt.Sender, evt.Type.String(), evt.ID, + evt.Content.AsMessage().Body) +} diff --git a/election/election.go b/election/election.go index 0c3cad2..0d90490 100644 --- a/election/election.go +++ b/election/election.go @@ -2,6 +2,7 @@ package election import ( "encoding/json" + "fmt" "sync" "maunium.net/go/mautrix/event" @@ -11,31 +12,34 @@ import ( type Election struct { sync.RWMutex - Candidates []Candidate `json:"candidates"` - CreateEvt event.Event `json:"create_evt"` - FinalVoters *[]id.EventID `json:"final_voters,omitempty"` - Joins map[id.EventID]*Voter `json:"joins"` - LocalVoter *LocalVoter `json:"local_voter,omitempty"` - Save func() error `json:"-"` - StartEvt *event.Event `json:"start_evt,omitempty"` - Title string `json:"title"` + Candidates []Candidate `json:"candidates"` + CreateEvt event.Event `json:"create_evt"` + FinalJoinIDs *[]id.EventID `json:"final_voters,omitempty"` + Joins map[id.EventID]*Voter `json:"joins"` + LocalVoter *LocalVoter `json:"local_voter,omitempty"` + RoomID id.RoomID `json:"room_id"` + StartID *id.EventID `json:"start_id,omitempty"` + Title string `json:"title"` + + // Save election to disk. Set by the containing ElectionsMap at runtime + Save func() `json:"-"` } -func NewElection(candidates []Candidate, createEvt event.Event, title string) *Election { +func NewElection(candidates []Candidate, createEvt event.Event, roomID id.RoomID, title string) *Election { return &Election{ Candidates: candidates, CreateEvt: createEvt, Joins: make(map[id.EventID]*Voter), + RoomID: roomID, Title: title, } } func (el *Election) UnmarshalJSON(b []byte) error { type Alias Election // create alias to prevent endless loop - tmp := (*Alias)(el) - err := json.Unmarshal(b, tmp) + err := json.Unmarshal(b, (*Alias)(el)) if err != nil { - return err + return fmt.Errorf("couldn't unmarshal election: %s", err) } // ensure these are the same pointer if el.LocalVoter != nil { diff --git a/election/event.go b/election/event.go new file mode 100644 index 0000000..d90ffc6 --- /dev/null +++ b/election/event.go @@ -0,0 +1,208 @@ +package election + +import ( + "encoding/json" + "fmt" + "sync" + + log "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// Used to store all relevant election Events that we've tried processing. A +// value of nil for a given event ID key in the Events map means that the +// corresponding event was processed unsuccessfully. Likewise, a non-nil value +// indicates that the event was processed successfully. +// +// This EventStore also uses the provided Client to fetch missing events and +// pipe them through the corresponding EventHandlers when needed. +type EventStore struct { + sync.RWMutex + Events map[id.EventID]*event.Event `json:"events"` + + Client *mautrix.Client `json:"-"` + EventHandlers map[event.Type]func(*event.Event)bool `json:"-"` +} + +func NewEventStore(client *mautrix.Client, eventHandlers map[event.Type]func(*event.Event)bool) *EventStore { + return &EventStore{ + Events: make(map[id.EventID]*event.Event), + + Client: client, + EventHandlers: eventHandlers, + } +} + +func (store *EventStore) UnmarshalJSON(b []byte) error { + type Alias EventStore + err := json.Unmarshal(b, (*Alias)(store)) + if err != nil { + return err + } + for _, event := range store.Events { + if event == nil { + continue + } + // event.Type.Class is "UnkownEventType" after unmarshlling + switch event.Type.Type { + case CreateElectionMessage.Type: + event.Type = CreateElectionMessage + case JoinElectionMessage.Type: + event.Type = JoinElectionMessage + case StartElectionMessage.Type: + event.Type = StartElectionMessage + case EvalsMessage.Type: + event.Type = EvalsMessage + case SumMessage.Type: + event.Type = SumMessage + case ResultMessage.Type: + event.Type = ResultMessage + } + if err = event.Content.ParseRaw(event.Type); err != nil { + return err + } + } + return nil +} + +type CreateEvent struct { + *event.Event + *CreateElectionContent +} + +type JoinEvent struct { + *event.Event + *JoinElectionContent +} + +type StartEvent struct { + *event.Event + *StartElectionContent +} + +type EvalsEvent struct { + *event.Event + *EvalsMessageContent +} + +type SumEvent struct { + *event.Event + *SumMessageContent +} + +func (store *EventStore) GetCreateEvent(roomID id.RoomID, createID id.EventID) *CreateEvent { + evt, err := store.getAndHandleEvent(roomID, createID, CreateElectionMessage) + if err != nil { + log.Warnf("an error occurred getting create event '%s': %s", createID, err) + return nil + } + if evt == nil { + return nil + } + return &CreateEvent{ + evt, + evt.Content.Parsed.(*CreateElectionContent), + } +} + +func (store *EventStore) GetJoinEvent(roomID id.RoomID, joinID id.EventID) *JoinEvent { + evt, err := store.getAndHandleEvent(roomID, joinID, JoinElectionMessage) + if err != nil { + log.Warnf("an error occurred getting join event '%s': %s", joinID, err) + return nil + } + if evt == nil { + return nil + } + return &JoinEvent{ + evt, + evt.Content.Parsed.(*JoinElectionContent), + } +} + +func (store *EventStore) GetStartEvent(roomID id.RoomID, startID id.EventID) *StartEvent { + evt, err := store.getAndHandleEvent(roomID, startID, StartElectionMessage) + if err != nil { + log.Warnf("an error occurred getting start event '%s': %s", startID, err) + return nil + } + if evt == nil { + return nil + } + return &StartEvent{ + evt, + evt.Content.Parsed.(*StartElectionContent), + } +} + +func (store *EventStore) GetEvalsEvent(roomID id.RoomID, evalsID id.EventID) *EvalsEvent { + evt, err := store.getAndHandleEvent(roomID, evalsID, EvalsMessage) + if err != nil { + log.Warnf("an error occurred getting evals event '%s': %s", evalsID, err) + return nil + } + if evt == nil { + return nil + } + return &EvalsEvent{ + evt, + evt.Content.Parsed.(*EvalsMessageContent), + } +} + +func (store *EventStore) GetSumEvent(roomID id.RoomID, sumID id.EventID) *SumEvent { + evt, err := store.getAndHandleEvent(roomID, sumID, EvalsMessage) + if err != nil { + log.Warnf("an error occurred getting sum event '%s': %s", sumID, err) + return nil + } + if evt == nil { + return nil + } + return &SumEvent{ + evt, + evt.Content.Parsed.(*SumMessageContent), + } +} + +func (store *EventStore) getAndHandleEvent(roomID id.RoomID, eventID id.EventID, eventType event.Type) (*event.Event, error) { + // see if we've handled this event before + store.RLock() + evt, exists := store.Events[eventID] + store.RUnlock() + if exists { + // This event was seen before. If evt != nil, it was handled + // successfully and vice versa. + return evt, nil + } + // we've never seen this event before + evt, err := store.fetchEvent(roomID, eventID) + if err != nil { + return nil, fmt.Errorf("couldn't fetch %s event '%s': %s", eventType.Type, eventID, err) + } + if evt == nil { + return nil, fmt.Errorf("fetched %s event '%s' is nil", eventType.Type, eventID) + } + // TODO check version here rather than in msg? + err = evt.Content.ParseRaw(eventType) + if err != nil { + return nil, fmt.Errorf("couldn't parse %s event '%s' content: %s", eventType.Type, eventID, err) + } + if !store.EventHandlers[eventType](evt) { + return nil, fmt.Errorf("couldn't handle %s event '%s'", eventType.Type, eventID) + } + return evt, nil +} + +func (store *EventStore) fetchEvent(roomID id.RoomID, eventID id.EventID) (*event.Event, error) { + evt, err := store.Client.GetEvent(roomID, eventID) + if err != nil { + return nil, err + } + if evt.Unsigned.RedactedBecause != nil { + return nil, fmt.Errorf("event %s was redacted", eventID) + } + return evt, nil +} diff --git a/election/map.go b/election/map.go index a44a0f5..e1327fe 100644 --- a/election/map.go +++ b/election/map.go @@ -10,48 +10,43 @@ import ( "time" "github.com/kyoh86/xdg" + log "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" ) type ElectionsMap struct { sync.RWMutex - // The version of the elections map. If the version in the file doesn't // match the current version then we can give a clear error message // telling the user to update their file to the current version. The // version in the file should only be different if there has been a // breaking change to the elections map format. - Version int `json:"version"` - + Version int `json:"version"` // Maps election create event IDs to the corresponding election. - M map[id.EventID]*Election `json:"elections"` - + Elections map[id.EventID]*Election `json:"elections"` + // See EventStore doc for explanation + EventStore *EventStore `json:"event_store,omitempty"` + // State store + NextBatch string `json:"next_batch"` + Rooms map[id.RoomID]*mautrix.Room `json:"rooms"` + UserID id.UserID `json:"user_id"` // Maps room to a list of the room's elections, reverse sorted by - // CreationTimestamp (i.e. newest to oldest). Here for convenience. - L map[id.RoomID][]*Election `json:"-"` - // The latest time L was updated - Ltime time.Time `json:"-"` - - // maps join event ID to voter. This is a convenience in scenarios when - // we don't know the exact election - Joins map[id.EventID]*Voter `json:"-"` - - // needed by the state store - NextBatch string `json:"next_batch"` - Rooms map[id.RoomID]*mautrix.Room `json:"rooms"` - UserID id.UserID `json:"userid"` + // CreationTimestamp (i.e. newest to oldest). Here for convenience for + // GUIs. Ltime represents the latest time L was updated. + L map[id.RoomID][]*Election `json:"-"` + Ltime time.Time `json:"-"` } -const electionsMapVersion = 3 +const electionsMapVersion = 4 var electionsFname = xdg.DataHome() + "/tallyard/elections.json" -func GetElectionsMap(userID id.UserID) (em *ElectionsMap, err error) { - em = &ElectionsMap{} - if _, err = os.Stat(electionsFname); os.IsNotExist(err) { +func GetElectionsMap(userID id.UserID) (*ElectionsMap, error) { + if _, err := os.Stat(electionsFname); os.IsNotExist(err) { return newElectionsMap(userID), nil } + em := &ElectionsMap{} jsonBytes, err := ioutil.ReadFile(electionsFname) if err != nil { return nil, fmt.Errorf("error reading data file %s: %s", electionsFname, err) @@ -68,19 +63,15 @@ func GetElectionsMap(userID id.UserID) (em *ElectionsMap, err error) { // TODO support multiple users? return nil, fmt.Errorf("user IDs don't match") } + // set runtime attributes for elections em.L = make(map[id.RoomID][]*Election, 0) em.Ltime = time.Now() - for createEventId, el := range em.M { - em.insort(createEventId, el) + storedElections := em.Elections + em.Elections = make(map[id.EventID]*Election, len(storedElections)) + for _, el := range storedElections { + em.AddElection(el) } - em.Joins = make(map[id.EventID]*Voter) - for _, el := range em.M { - el.Save = em.Save - for joinEventId, voter := range el.Joins { - em.Joins[joinEventId] = voter - } - } - return + return em, nil } func (em *ElectionsMap) UnmarshalJSON(b []byte) error { @@ -92,7 +83,9 @@ func (em *ElectionsMap) UnmarshalJSON(b []byte) error { for _, room := range em.Rooms { for eventType, events := range room.State { for _, evt := range events { - evt.Content.ParseRaw(eventType) + if err = evt.Content.ParseRaw(eventType); err != nil { + return err + } } } } @@ -101,68 +94,57 @@ func (em *ElectionsMap) UnmarshalJSON(b []byte) error { func newElectionsMap(userID id.UserID) *ElectionsMap { return &ElectionsMap{ - Version: electionsMapVersion, - M: make(map[id.EventID]*Election), - L: make(map[id.RoomID][]*Election), - Ltime: time.Now(), - Joins: make(map[id.EventID]*Voter), - Rooms: make(map[id.RoomID]*mautrix.Room), - UserID: userID, + Version: electionsMapVersion, + Elections: make(map[id.EventID]*Election), + L: make(map[id.RoomID][]*Election), + Ltime: time.Now(), + Rooms: make(map[id.RoomID]*mautrix.Room), + UserID: userID, } } -func (em *ElectionsMap) Save() error { +func (em *ElectionsMap) Save() { em.RLock() defer em.RUnlock() - for _, el := range em.M { + for _, el := range em.Elections { el.RLock() defer el.RUnlock() } jsonBytes, err := json.Marshal(em) if err != nil { - return fmt.Errorf("couldn't marshal elections: %s", err) + log.Errorf("couldn't marshal elections: %s", err) } err = ioutil.WriteFile(electionsFname, jsonBytes, 0600) if err != nil { - return fmt.Errorf("couldn't save elections: %s", err) + log.Errorf("couldn't save elections: %s", err) } - return nil -} - -func (em *ElectionsMap) Get(createEventID id.EventID) *Election { - em.RLock() - defer em.RUnlock() - return em.M[createEventID] + log.Info("saved elections map") } -func (em *ElectionsMap) GetOk(createEventID id.EventID) (*Election, bool) { +func (em *ElectionsMap) GetElection(createID id.EventID) *Election { em.RLock() defer em.RUnlock() - el, ok := em.M[createEventID] - return el, ok + return em.Elections[createID] } -func (em *ElectionsMap) SetIfNotExists(createEventID id.EventID, el *Election) { +func (em *ElectionsMap) AddElection(el *Election) { em.Lock() defer em.Save() defer em.Unlock() - _, exists := em.M[createEventID] + _, exists := em.Elections[el.CreateEvt.ID] if exists { + log.Warnf("election %s was already added to elections map", el.CreateEvt.ID) return } - em.set(createEventID, el) -} - -func (em *ElectionsMap) set(createEventID id.EventID, el *Election) { el.Lock() defer el.Unlock() + em.Elections[el.CreateEvt.ID] = el el.Save = em.Save - em.M[createEventID] = el - em.insort(createEventID, el) + em.insortElection(el.CreateEvt.ID, el) } -func (em *ElectionsMap) insort(createEventID id.EventID, el *Election) { - list := em.L[el.CreateEvt.RoomID] +func (em *ElectionsMap) insortElection(createID id.EventID, el *Election) { + list := em.L[el.RoomID] i := sort.Search(len(list), func(i int) bool { return list[i].CreateEvt.Timestamp < el.CreateEvt.Timestamp }) @@ -170,6 +152,12 @@ func (em *ElectionsMap) insort(createEventID id.EventID, el *Election) { copy(newList[:i], list[:i]) newList[i] = el copy(newList[i+1:], list[i:]) - em.L[el.CreateEvt.RoomID] = newList + em.L[el.RoomID] = newList em.Ltime = time.Now() } + +func (em *ElectionsMap) SetEventStore(eventStore *EventStore) { + em.Lock() + defer em.Unlock() + em.EventStore = eventStore +} diff --git a/election/msg.go b/election/msg.go index 7ce1c67..8dd0bcb 100644 --- a/election/msg.go +++ b/election/msg.go @@ -2,11 +2,13 @@ package election import ( "encoding/base64" + "fmt" "math/big" "reflect" log "github.com/sirupsen/logrus" "golang.org/x/crypto/nacl/box" + "golang.org/x/mod/semver" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -29,9 +31,9 @@ var ( Type: "xyz.tallyard.start", Class: event.MessageEventType, } - // indicate's user's evaluation of their polynomial using others' inputs - EvalMessage = event.Type{ - Type: "xyz.tallyard.eval", + // indicate's user's evaluations of their polynomial using others' inputs + EvalsMessage = event.Type{ + Type: "xyz.tallyard.evals", Class: event.MessageEventType, } // indicates a user's individual summation @@ -47,306 +49,556 @@ var ( ) type CreateElectionContent struct { + Version string `json:"version"` + Candidates []Candidate `json:"candidates"` Title string `json:"title"` - Version string `json:"version"` } type JoinElectionContent struct { - CreateEventId id.EventID `json:"create_event_id"` - Input string `json:"input"` - NaclPublicKey string `json:"nacl_public_key"` + Version string `json:"version"` + + CreateID id.EventID `json:"create_id"` + Input string `json:"input"` + PubKey string `json:"pub_key"` } type StartElectionContent struct { - CreateEventId id.EventID `json:"create_event_id"` - VoterJoinIds []id.EventID `json:"voter_join_ids"` + Version string `json:"version"` + + CreateID id.EventID `json:"create_id"` + JoinIDs []id.EventID `json:"join_ids"` } -type EvalMessageContent struct { - JoinEventId id.EventID `json:"join_event_id"` - Outputs map[id.EventID]string `json:"outputs"` +type EvalsMessageContent struct { + Version string `json:"version"` + + Evals map[id.EventID]string `json:"evals"` + JoinID id.EventID `json:"join_id"` + StartID id.EventID `json:"start_id"` } type SumMessageContent struct { - JoinEventId id.EventID `json:"join_event_id"` - Sum string `json:"sum"` + Version string `json:"version"` + + EvalsIDs []id.EventID `json:"evals_ids"` + JoinID id.EventID `json:"join_id"` + Sum string `json:"sum"` } type ResultMessageContent struct { - JoinEventId id.EventID `json:"join_event_id"` - Result string `json:"result"` + Version string `json:"version"` + + JoinID id.EventID `json:"join_id"` + Result string `json:"result"` + SumIDs []id.EventID `json:"sums"` } func init() { event.TypeMap[CreateElectionMessage] = reflect.TypeOf(CreateElectionContent{}) event.TypeMap[JoinElectionMessage] = reflect.TypeOf(JoinElectionContent{}) event.TypeMap[StartElectionMessage] = reflect.TypeOf(StartElectionContent{}) - event.TypeMap[EvalMessage] = reflect.TypeOf(EvalMessageContent{}) + event.TypeMap[EvalsMessage] = reflect.TypeOf(EvalsMessageContent{}) event.TypeMap[SumMessage] = reflect.TypeOf(SumMessageContent{}) event.TypeMap[ResultMessage] = reflect.TypeOf(ResultMessageContent{}) } -func SetupEventHooks(client *mautrix.Client, syncer mautrix.ExtensibleSyncer, elections *ElectionsMap) { - wrapper := func(f func(*event.Event)) func(mautrix.EventSource, *event.Event) { - return func(source mautrix.EventSource, evt *event.Event) { - log.Debugf("%[5]d: <%[1]s> %[4]s (%[2]s/%[3]s)\n", - evt.Sender, evt.Type.String(), evt.ID, - evt.Content.AsMessage().Body, source) +func (elections *ElectionsMap) SetupEventHooks(client *mautrix.Client, syncer mautrix.ExtensibleSyncer) { + elections.Lock() + defer elections.Unlock() + + eventHandlers := make(map[event.Type]func(*event.Event)bool) + eventStore := elections.EventStore + + if eventStore == nil { + eventStore = NewEventStore(client, eventHandlers) + } else { + eventStore.Client = client + eventStore.EventHandlers = eventHandlers + } + + wrapper := func(f func(*event.Event) bool) func(*event.Event) bool { + return func(evt *event.Event) (success bool) { if evt.Unsigned.RedactedBecause != nil { - log.Debugf("%s redacted\n", evt.ID.String()) + log.Debugf("event %s was redacted", evt.ID) return } - f(evt) + eventStore.RLock() + handledEvent, exists := eventStore.Events[evt.ID] + eventStore.RUnlock() + if exists { + log.Debugf("event %s was already handled", evt.ID) + return handledEvent != nil + } + success = f(evt) + eventStore.Lock() + // see EventStore doc for success explanation + if success { + eventStore.Events[evt.ID] = evt + } else { + eventStore.Events[evt.ID] = nil + } + eventStore.Unlock() + return } } - syncer.OnEventType(CreateElectionMessage, wrapper(func(evt *event.Event) { - OnCreateElectionMessage(evt, elections) - })) - syncer.OnEventType(JoinElectionMessage, wrapper(func(evt *event.Event) { - OnJoinElectionMessage(client, evt, elections) - })) - syncer.OnEventType(StartElectionMessage, wrapper(func(evt *event.Event) { - OnStartElectionMessage(client, evt, elections) - })) - syncer.OnEventType(EvalMessage, wrapper(func(evt *event.Event) { - OnEvalMessage(client, evt, elections) - })) - syncer.OnEventType(SumMessage, wrapper(func(evt *event.Event) { - OnSumMessage(client, evt, elections) - })) - syncer.OnEventType(ResultMessage, wrapper(func(evt *event.Event) { - OnResultMessage(client, evt, elections) - })) + + eventHandlers[CreateElectionMessage] = wrapper(elections.onCreateElectionMessage) + eventHandlers[JoinElectionMessage] = wrapper(elections.onJoinElectionMessage) + eventHandlers[StartElectionMessage] = wrapper(elections.onStartElectionMessage) + eventHandlers[EvalsMessage] = wrapper(elections.onEvalsMessage) + eventHandlers[SumMessage] = wrapper(elections.onSumMessage) + eventHandlers[ResultMessage] = wrapper(elections.onResultMessage) + + for eventType, handler := range eventHandlers { + func(handler func(*event.Event) bool) { + syncer.OnEventType(eventType, func(_ mautrix.EventSource, evt *event.Event) { + handler(evt) + }) + }(handler) + } + + elections.EventStore = eventStore } -func OnCreateElectionMessage(evt *event.Event, elections *ElectionsMap) { - // TODO: check version +type logfFunc func(format string, args ...interface{}) +// error: short-circuit bug that happened locally; +// warn: short-circuit bug that happened remotely; +// debug: short-circuit info +func logFuncs(baseFormat string) (error, warn, debug logfFunc) { + return func(reason string, args ...interface{}) { + log.Errorf(baseFormat, fmt.Sprintf(reason, args...)) + }, func(reason string, args ...interface{}) { + log.Warnf(baseFormat, fmt.Sprintf(reason, args...)) + }, func(reason string, args ...interface{}) { + log.Debugf(baseFormat, fmt.Sprintf(reason, args...)) + } +} + +func incompatibleVersion(version string) bool { + return semver.Compare( + semver.MajorMinor(version), + semver.MajorMinor(Version), + ) != 0 +} + +func (elections *ElectionsMap) onCreateElectionMessage(evt *event.Event) (success bool) { + _, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's create msg (%s) since %s", evt.Sender, evt.ID, "%s")) + content, ok := evt.Content.Parsed.(*CreateElectionContent) if !ok { - log.Warnf("ignoring %s's create since we couldn't cast message content to CreateElectionContent", evt.Sender) + warnf("we couldn't cast message content to CreateElectionContent", evt.Sender) + return + } + + if incompatibleVersion(content.Version) { + debugf("the version is incompatible") return } - elections.SetIfNotExists(evt.ID, NewElection(content.Candidates, *evt, content.Title)) + + elections.AddElection(NewElection( + content.Candidates, + *evt, + evt.RoomID, + content.Title, + )) + + return true } -func OnJoinElectionMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) { - content, ok := evt.Content.Parsed.(*JoinElectionContent) - if !ok { - log.Warnf("ignoring %s's join msg since we couldn't cast message content to JoinElectionContent", evt.Sender) +func (elections *ElectionsMap) onJoinElectionMessage(evt *event.Event) (success bool) { + errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's join msg (%s) since %s", evt.Sender, evt.ID, "%s")) + + content, exists := evt.Content.Parsed.(*JoinElectionContent) + if !exists { + warnf("we couldn't cast message content to JoinElectionContent") return } - el := getElection(client, evt.RoomID, content.CreateEventId, elections) - if el == nil { - log.Warnf("ignoring %s's join msg since the election doesn't exist", evt.Sender) + + if incompatibleVersion(content.Version) { + debugf("the version is incompatible") return } - el.Lock() - defer el.Unlock() - _, voterExists := el.Joins[evt.ID] - if voterExists { - log.Debugf("ignoring %s's join msg since we already have their info", evt.Sender) + + createEvt := elections.EventStore.GetCreateEvent(evt.RoomID, content.CreateID) + if createEvt == nil { + debugf("we couldn't get the create event, %s", content.CreateID) return } + bytes, err := base64.StdEncoding.DecodeString(content.Input) if err != nil { - log.Warnf("ignoring %s's join msg since we couldn't decode their input", evt.Sender) + warnf("we couldn't decode their input: %s", err) return } input := new(big.Int).SetBytes(bytes) - bytes, err = base64.StdEncoding.DecodeString(content.NaclPublicKey) + + bytes, err = base64.StdEncoding.DecodeString(content.PubKey) if err != nil { - log.Warnf("ignoring %s's join msg since we couldn't decode their public key: %s", evt.Sender, err) + warnf("we couldn't decode their public key: %s", err) return } var pubKey [32]byte copy(pubKey[:], bytes) - voter := NewVoter(evt.Sender, input, &pubKey, evt) - el.Joins[evt.ID] = voter - elections.Joins[evt.ID] = voter + + el := elections.GetElection(createEvt.ID) + if el == nil { + // should never happen because we retrieved the create event + // above + errorf("election %s doesn't exist", createEvt.ID) + return + } + + el.Lock() + defer el.Save() + defer el.Unlock() + + el.Joins[evt.ID] = NewVoter(input, evt, &pubKey) + + return true } -func OnStartElectionMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) { +func (elections *ElectionsMap) onStartElectionMessage(evt *event.Event) (success bool) { + errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's start msg (%s) since %s", evt.Sender, evt.ID, "%s")) + content, ok := evt.Content.Parsed.(*StartElectionContent) if !ok { - log.Warnf("ignoring %s's start msg since we couldn't cast message content to StartElectionContent", evt.Sender) + warnf("we couldn't cast message content to StartElectionContent") + return + } + + if incompatibleVersion(content.Version) { + debugf("the version is incompatible") + return + } + + createEvt := elections.EventStore.GetCreateEvent(evt.RoomID, content.CreateID) + if createEvt == nil { + debugf("we couldn't get the create event, %s", content.CreateID) return } - el := getElection(client, evt.RoomID, content.CreateEventId, elections) + + for _, joinID := range content.JoinIDs { + joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, joinID) + if joinEvt == nil { + debugf("we couldn't get the join event %s", joinID) + return + } + if joinEvt.CreateID != content.CreateID { + warnf("the join event %s didn't have the same create ID, %s, as the start event did", + joinID, content.CreateID) + return + } + } + + if createEvt.Sender != evt.Sender { + warnf("they didn't create the election") + return + } + + el := elections.GetElection(createEvt.ID) if el == nil { - log.Warnf("ignoring %s's start msg since the election doesn't exist", evt.Sender) + // should never happen because we retrieved the craete event + // above + errorf("election %s doesn't exist", createEvt.ID) return } + + // election should exist since we were able to getCreateEvent el.Lock() + defer el.Save() defer el.Unlock() - if evt.Sender != el.CreateEvt.Sender { - log.Warnf("ignoring %s's start msg since they didn't start the election", evt.Sender) - return - } - // TODO we should probably just bail when there are multiple start messages - if el.StartEvt != nil && el.StartEvt.Timestamp < evt.Timestamp { - log.Warnf("ignoring %s's start msg since the election's already been started", evt.Sender) + + if el.StartID != nil { + warnf("the election's already been started") return } - el.StartEvt = evt - el.FinalVoters = &content.VoterJoinIds - // TODO check election voters + + el.StartID = &evt.ID + el.FinalJoinIDs = &content.JoinIDs + + return true } -func OnEvalMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) { - content, ok := evt.Content.Parsed.(*EvalMessageContent) +func (elections *ElectionsMap) onEvalsMessage(evt *event.Event) (success bool) { + errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's evals msg (%s) since %s", evt.Sender, evt.ID, "%s")) + + content, ok := evt.Content.Parsed.(*EvalsMessageContent) if !ok { - log.Warn("ignoring %s's eval msg since we couldn't cast message content to EvalMessageContent", evt.Sender) + warnf("we couldn't cast message content to EvalMessageContent") return } - voter := getVoter(client, evt.RoomID, content.JoinEventId, elections) - if voter == nil { - log.Warnf("ignoring %s's eval msg since voter doesn't exist", evt.Sender) + + if incompatibleVersion(content.Version) { + debugf("the version is incompatible") + return + } + + startEvt := elections.EventStore.GetStartEvent(evt.RoomID, content.StartID) + if startEvt == nil { + debugf("we couldn't get the start event, %s", content.StartID) + return + } + + joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, content.JoinID) + if joinEvt == nil { + debugf("we couldn't get the join event, %s", content.JoinID) + return + } + + // ensure keys of Evals are in JoinIDs of start event + if len(content.Evals) != len(startEvt.JoinIDs) { + warnf("the number of evals is wrong (%s instead of %s)", + len(content.Evals), len(startEvt.JoinIDs)) return } - // if Content were faulty, voter would have been nil - createEventId := voter.JoinEvt.Content.Parsed.(*JoinElectionContent).CreateEventId - el := getElection(client, evt.RoomID, createEventId, elections) + for _, joinID := range startEvt.JoinIDs { + if _, exists := content.Evals[joinID]; !exists { + warnf("it doesn't include an eval for join %s", joinID) + return + } + } + // is the voter's join ID in the start event's list of join IDs? + if _, exists := content.Evals[joinEvt.ID]; !exists { + warnf("the join event %s is not listed in the start event's join IDs", joinEvt.ID) + return + } + + el := elections.GetElection(startEvt.CreateID) if el == nil { - log.Warnf("ignoring %s's eval msg since the election doesn't exist", evt.Sender) + // should never happen because we retrieved the start/join + // events above + errorf("election %s doesn't exist", startEvt.CreateID) return } + el.Lock() + defer el.Save() defer el.Unlock() - if el.LocalVoter == nil { + + voter := el.Joins[joinEvt.ID] + if voter == nil { + // should never happen because we called getJoinEvent above + errorf("voter %s doesn't exist", joinEvt.ID) return } - encodedEncryptedOutput, ok := content.Outputs[el.LocalVoter.JoinEvt.ID] - if !ok { - log.Errorf("our user ID was not included in an eval message! The election will be unable to finish; blame %s", evt.Sender) + + if voter.EvalsID != nil { + warnf("voter submitted multiple evals events") + return + } + + voter.EvalsID = &evt.ID + + if el.LocalVoter == nil { + return true + } + + encodedEncryptedOutput, exists := content.Evals[el.LocalVoter.JoinEvt.ID] + // If our ID doesn't exist in the keys of evalsEvtContent.Evals, our + // JoinID wasn't included in the startEvtContent.JoinIDs (since we + // checked that the two are equivalent above). I'm checking membership + // in evalsEvtContent.Evals instead of startEvtContent.JoinIDs because + // maps are easier than slices for that. + if !exists { + debugf("we didn't join the election in time (or the election creator excluded us)") return } encryptedOutput, err := base64.StdEncoding.DecodeString(encodedEncryptedOutput) if err != nil { - log.Errorf("couldn't decode %s's encrypted output: %s", evt.Sender, err) + warnf("couldn't decode %s's encrypted output: %s", evt.Sender, err) return } - var decryptNonce [24]byte copy(decryptNonce[:], encryptedOutput[:24]) decryptedOutput, ok := box.Open(nil, encryptedOutput[24:], &decryptNonce, &voter.PubKey, &el.LocalVoter.PrivKey) if !ok { - log.Errorf("decryption error") + warnf("couldn't decrypt eval for us") return } - voter.Eval = new(big.Int).SetBytes(decryptedOutput) + + return true } -func OnSumMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) { +func (elections *ElectionsMap) onSumMessage(evt *event.Event) (success bool) { + errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's sum msg (%s) since %s", evt.Sender, evt.ID, "%s")) + content, ok := evt.Content.Parsed.(*SumMessageContent) if !ok { - log.Warnf("ignoring %s's sum since we couldn't cast message content to SumMessageContent", evt.Sender) + warnf("we couldn't cast message content to SumMessageContent") return } - voter := getVoter(client, evt.RoomID, content.JoinEventId, elections) - if voter == nil { - log.Warnf("ignoring %s's sum since voter doesn't exist", evt.Sender) + + if incompatibleVersion(content.Version) { + debugf("the version is incompatible") return } - // if Content were faulty, voter would have been nil - createEventId := voter.JoinEvt.Content.Parsed.(*JoinElectionContent).CreateEventId - el := getElection(client, evt.RoomID, createEventId, elections) + + joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, content.JoinID) + if joinEvt == nil { + debugf("we couldn't get the join event, %s", content.JoinID) + return + } + + if len(content.EvalsIDs) == 0 { + warnf("evals length is zero") + return + } + + joinIDs := make(map[id.EventID]id.EventID) + var startID id.EventID + + for _, evalsID := range content.EvalsIDs { + evalsEvt := elections.EventStore.GetEvalsEvent(evt.RoomID, evalsID) + if evalsEvt == nil { + debugf("we couldn't get an evals event, %s", evalsID) + return + } + + // ensure all evals have the same start ID + if startID == "" { + startID = evalsEvt.StartID + } else if evalsEvt.StartID != startID { + warnf("at least two evals have different startIDs (%s, %s, ...)", startID, evalsEvt.StartID) + return + } + + // ensure no two evals use the same join ID. This'll also catch + // duplicates in EvalsIDs + if same, exists := joinIDs[evalsEvt.JoinID]; exists { + warnf("at least two evals use the same join ID (%s, %s, ...)", same, evalsID) + return + } + + joinIDs[evalsEvt.JoinID] = evalsID + + // Remember, we don't need to check that the evals's join ID is + // in the start event because we successfully retrieved the + // evals event above. Recursion is wonderful. + } + + joinEvtContent := joinEvt.Content.Parsed.(*JoinElectionContent) + el := elections.GetElection(joinEvtContent.CreateID) if el == nil { - log.Warnf("ignoring %s's sum since the election does not exist", evt.Sender) + // should never happen because we retrieved the start/join + // events above + errorf("election %s doesn't exist", joinEvtContent.CreateID) return } + el.Lock() + defer el.Save() defer el.Unlock() + + voter := el.Joins[joinEvt.ID] + if voter == nil { + // should never happen because we called getJoinEvent above + errorf("voter %s doesn't exist", joinEvt.ID) + return + } + + if voter.SumID != nil { + warnf("voter submitted multiple sum events") + return + } + + voter.SumID = &evt.ID + bytes, err := base64.StdEncoding.DecodeString(content.Sum) if err != nil { - log.Warnf("ignoring %s's sum since we couldn't decode their sum: %s", evt.Sender, err) + warnf("we couldn't decode their sum: %s", err) return } + sum := new(big.Int).SetBytes(bytes) voter.Sum = sum + + return true } -func OnResultMessage(client *mautrix.Client, evt *event.Event, elections *ElectionsMap) { +func (elections *ElectionsMap) onResultMessage(evt *event.Event) (success bool) { + errorf, warnf, debugf := logFuncs(fmt.Sprintf("ignoring %s's result msg (%s) since %s", evt.Sender, evt.ID, "%s")) + content, ok := evt.Content.Parsed.(*ResultMessageContent) if !ok { - log.Warnf("ignoring %s's result since we couldn't cast message content to ResultMessageContent", evt.Sender) + warnf("we couldn't cast message content to ResultMessageContent") return } - voter := getVoter(client, evt.RoomID, content.JoinEventId, elections) - if voter == nil { - log.Warnf("ignoring %s's sum since voter doesn't exist", evt.Sender) - return - } - createEventId := voter.JoinEvt.Content.Parsed.(*JoinElectionContent).CreateEventId - el := getElection(client, evt.RoomID, createEventId, elections) - if el == nil { - log.Warnf("ignoring %s's result since the election does not exist", evt.Sender) + + if incompatibleVersion(content.Version) { + debugf("the version is incompatible") return } - el.Lock() - defer el.Unlock() - bytes, err := base64.StdEncoding.DecodeString(content.Result) - if err != nil { - log.Warnf("ignoring %s's result since we couldn't decode the result: %s", evt.Sender, err) + + if len(content.SumIDs) == 0 { + warnf("sums length is zero") return } - result := new(big.Int).SetBytes(bytes) - voter.Result = result -} -func getElection(client *mautrix.Client, roomID id.RoomID, createEventId id.EventID, elections *ElectionsMap) *Election { - el, exists := elections.GetOk(createEventId) - if exists { - return el - } - createEvent, err := client.GetEvent(roomID, createEventId) - if err != nil { - log.Warnf("couldn't retrieve election create event: %s", err) - return nil - } - if createEvent.Unsigned.RedactedBecause != nil { - log.Debug("election redacted") - return nil - } - err = createEvent.Content.ParseRaw(CreateElectionMessage) - if err != nil { - log.Errorf("couldn't parse create event: %s", err) - return nil + var evalsIDs map[id.EventID]struct{} + + for _, sumID := range content.SumIDs { + sumEvt := elections.EventStore.GetSumEvent(evt.RoomID, sumID) + if sumEvt == nil { + warnf("we couldn't get the sum event %s", sumID) + return + } + if evalsIDs == nil { + // run on first iteration only + evalsIDs = make(map[id.EventID]struct{}) + for _, evalID := range sumEvt.EvalsIDs { + evalsIDs[evalID] = struct{}{} + } + continue + } + if len(sumEvt.EvalsIDs) != len(evalsIDs) { + warnf("sum events %s and %s don't have the same number of evals events", + sumID, content.SumIDs[0]) + } + for _, evalsID := range sumEvt.EvalsIDs { + if _, exists := evalsIDs[evalsID]; !exists { + warnf("evals ID %s exists in one evals event but not another", evalsID) + return + } + } } - OnCreateElectionMessage(createEvent, elections) - el, exists = elections.GetOk(createEventId) - if !exists { - log.Warn("couldn't create election") - return nil + + joinEvt := elections.EventStore.GetJoinEvent(evt.RoomID, content.JoinID) + if joinEvt == nil { + debugf("we couldn't get the join, %s", content.JoinID) + return } - return el -} -func getVoter(client *mautrix.Client, roomID id.RoomID, joinEventId id.EventID, elections *ElectionsMap) *Voter { - voter, exists := elections.Joins[joinEventId] - if exists { - return voter + el := elections.GetElection(joinEvt.CreateID) + if el == nil { + // should never happen because we retrieved the join event + errorf("election %s doesn' exist", joinEvt.CreateID) + return } - joinEvent, err := client.GetEvent(roomID, joinEventId) - if err != nil { - log.Warnf("couldn't retrieve join event: %s", err) - return nil + + voter := el.Joins[joinEvt.ID] + if voter == nil { + errorf("voter %s doesn't exist", joinEvt.ID) + return } - if joinEvent.Unsigned.RedactedBecause != nil { - log.Debug("join redacted") - return nil + + if voter.ResultID != nil { + warnf("voter %s submitted multiple results", joinEvt.ID) + return } - err = joinEvent.Content.ParseRaw(JoinElectionMessage) + + bytes, err := base64.StdEncoding.DecodeString(content.Result) if err != nil { - log.Errorf("couldn't parse join event: %s", err) - return nil - } - OnJoinElectionMessage(client, joinEvent, elections) - voter, exists = elections.Joins[joinEventId] - if !exists { - log.Warn("couldn't find voter") - return nil + warnf("we couldn't decode the result: %s", err) + return } - return voter + + el.Lock() + defer el.Unlock() + + result := new(big.Int).SetBytes(bytes) + voter.Result = result + + return true } diff --git a/election/version.go b/election/version.go index a13309a..a318485 100644 --- a/election/version.go +++ b/election/version.go @@ -1,3 +1,3 @@ package election -const Version string = "0.3.0" +const Version string = "v0.3.0" diff --git a/election/voter.go b/election/voter.go index ebd5d6f..6618e39 100644 --- a/election/voter.go +++ b/election/voter.go @@ -20,13 +20,16 @@ import ( ) type Voter struct { - Eval *big.Int `json:"eval,omitempty"` - Input big.Int `json:"input"` - JoinEvt event.Event `json:"join_evt"` - PubKey [32]byte `json:"pub_key"` - Result *big.Int `json:"result,omitempty"` - Sum *big.Int `json:"sum,omitempty"` - UserID id.UserID `json:"user_id"` + Input big.Int `json:"input"` + JoinEvt event.Event `json:"join_evt"` + PubKey [32]byte `json:"pub_key"` + + Eval *big.Int `json:"eval,omitempty"` + EvalsID *id.EventID `json:"evals_id,omitempty"` + Result *big.Int `json:"result,omitempty"` + ResultID *id.EventID `json:"result_id,omitempty"` + Sum *big.Int `json:"sum,omitempty"` + SumID *id.EventID `json:"sum_id,omitempty"` } type LocalVoter struct { @@ -36,132 +39,119 @@ type LocalVoter struct { PrivKey [32]byte `json:"priv_key"` } -func NewVoter(userID id.UserID, input *big.Int, pubKey *[32]byte, joinEvt *event.Event) *Voter { +func NewVoter(input *big.Int, joinEvt *event.Event, pubKey *[32]byte) *Voter { return &Voter{ Input: *new(big.Int).Set(input), JoinEvt: *joinEvt, PubKey: *pubKey, - UserID: userID, } } -func CreateElection(client *mautrix.Client, candidates []Candidate, title string, roomID id.RoomID, elections *ElectionsMap) (*Election, error) { +func (elections *ElectionsMap) CreateElection(client *mautrix.Client, candidates []Candidate, title string, roomID id.RoomID) (*Election, error) { resp, err := client.SendMessageEvent(roomID, CreateElectionMessage, CreateElectionContent{ + Version: Version, Candidates: candidates, Title: title, - Version: Version, }) if err != nil { return nil, err } - createEvt, err := client.GetEvent(roomID, resp.EventID) - if err != nil { - return nil, err - } - err = createEvt.Content.ParseRaw(CreateElectionMessage) - if err != nil { - return nil, err + createEvt := elections.EventStore.GetCreateEvent(roomID, resp.EventID) + if createEvt == nil { + return nil, fmt.Errorf("couldn't process our own create event, %s", resp.EventID) } - OnCreateElectionMessage(createEvt, elections) - el, exists := elections.GetOk(resp.EventID) - if !exists { - return nil, errors.New("couldn't create election") - } - return el, nil + return elections.GetElection(createEvt.ID), nil } -func (el *Election) JoinElection(client *mautrix.Client) error { +func (el *Election) JoinElection(client *mautrix.Client, eventStore *EventStore) error { pubKey, privKey, err := box.GenerateKey(rand.Reader) if err != nil { return err } + input, err := math.RandomBigInt(1024, false) if err != nil { return err } - el.Lock() - defer el.Save() - defer el.Unlock() - - resp, err := client.SendMessageEvent(el.CreateEvt.RoomID, JoinElectionMessage, JoinElectionContent{ - CreateEventId: el.CreateEvt.ID, - Input: base64.StdEncoding.EncodeToString(input.Bytes()), - NaclPublicKey: base64.StdEncoding.EncodeToString((*pubKey)[:]), + resp, err := client.SendMessageEvent(el.RoomID, JoinElectionMessage, JoinElectionContent{ + Version: Version, + CreateID: el.CreateEvt.ID, + Input: base64.StdEncoding.EncodeToString(input.Bytes()), + PubKey: base64.StdEncoding.EncodeToString((*pubKey)[:]), }) if err != nil { return err } - joinEvt, err := client.GetEvent(el.CreateEvt.RoomID, resp.EventID) - if err != nil { - return err - } - err = joinEvt.Content.ParseRaw(JoinElectionMessage) - if err != nil { - return err + joinEvt := eventStore.GetJoinEvent(el.RoomID, resp.EventID) + if joinEvt == nil { + return fmt.Errorf("couldn't process our own join event, %s", resp.EventID) } + el.Lock() + defer el.Save() + defer el.Unlock() + el.LocalVoter = &LocalVoter{ - Voter: NewVoter(client.UserID, input, pubKey, joinEvt), + Voter: el.Joins[joinEvt.ID], PrivKey: *privKey, } - el.Joins[resp.EventID] = el.LocalVoter.Voter + return nil } -func (el *Election) StartElection(client *mautrix.Client) error { +func (el *Election) StartElection(client *mautrix.Client, eventStore *EventStore) error { // TODO err from this function if we didn't create the election - el.Lock() - defer el.Save() - defer el.Unlock() userIdMap := make(map[id.UserID]*Voter) + + el.RLock() // one vote per userID for _, voter := range el.Joins { - prevVoter, exists := userIdMap[voter.UserID] - if exists { - // use latest join - if voter.JoinEvt.Timestamp > prevVoter.JoinEvt.Timestamp { - userIdMap[voter.UserID] = voter - } - } else { - userIdMap[voter.UserID] = voter + prevVoter, exists := userIdMap[voter.JoinEvt.Sender] + if !exists { + userIdMap[voter.JoinEvt.Sender] = voter + continue + } + + // use latest join + if voter.JoinEvt.Timestamp > prevVoter.JoinEvt.Timestamp { + userIdMap[voter.JoinEvt.Sender] = voter } } + el.RUnlock() + voters := make([]id.EventID, 0, len(userIdMap)) for _, voter := range userIdMap { voters = append(voters, voter.JoinEvt.ID) } - resp, err := client.SendMessageEvent(el.CreateEvt.RoomID, StartElectionMessage, StartElectionContent{ - CreateEventId: el.CreateEvt.ID, - VoterJoinIds: voters, + + resp, err := client.SendMessageEvent(el.RoomID, StartElectionMessage, StartElectionContent{ + Version: Version, + CreateID: el.CreateEvt.ID, + JoinIDs: voters, }) if err != nil { return err } - startEvt, err := client.GetEvent(el.CreateEvt.RoomID, resp.EventID) - if err != nil { - return err - } - err = startEvt.Content.ParseRaw(StartElectionMessage) - if err != nil { - return err + + startEvt := eventStore.GetStartEvent(el.RoomID, resp.EventID) + if startEvt == nil { + return fmt.Errorf("couldn't process our own start event, %s", resp.EventID) } - // OnStartElectionMessage(client, startEvt, elections) - el.StartEvt = startEvt - el.FinalVoters = &voters + return nil } -func (el *Election) WaitForVoters(client *mautrix.Client) error { +func (el *Election) WaitForJoins(client *mautrix.Client) error { fmt.Println("waiting for others...") el.RLock() - if el.StartEvt == nil { - return errors.New("WaitForVoters called before election started") + if el.StartID == nil { + return errors.New("WaitForJoins called before election started") } - finalVoters := *el.FinalVoters + finalVoters := *el.FinalJoinIDs el.RUnlock() var wg sync.WaitGroup for _, voterJoinId := range finalVoters { @@ -184,24 +174,23 @@ func (el *Election) SendEvals(client *mautrix.Client) error { el.Lock() defer el.Save() defer el.Unlock() - content := EvalMessageContent{ - JoinEventId: el.LocalVoter.JoinEvt.ID, - Outputs: make(map[id.EventID]string), - } - for _, voterJoinId := range *el.FinalVoters { - voter := el.Joins[voterJoinId] - output := el.LocalVoter.Poly.Eval(&voter.Input) + evals := make(map[id.EventID]string) + for _, joinID := range *el.FinalJoinIDs { + voter := el.Joins[joinID] + eval := el.LocalVoter.Poly.Eval(&voter.Input) var nonce [24]byte if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { return err } - if voter.JoinEvt.ID == el.LocalVoter.JoinEvt.ID { - el.LocalVoter.Eval = output - } - encrypted := box.Seal(nonce[:], output.Bytes(), &nonce, &voter.PubKey, &el.LocalVoter.PrivKey) - content.Outputs[voter.JoinEvt.ID] = base64.StdEncoding.EncodeToString(encrypted) - } - _, err := client.SendMessageEvent(el.CreateEvt.RoomID, EvalMessage, content) + encrypted := box.Seal(nonce[:], eval.Bytes(), &nonce, &voter.PubKey, &el.LocalVoter.PrivKey) + evals[voter.JoinEvt.ID] = base64.StdEncoding.EncodeToString(encrypted) + } + _, err := client.SendMessageEvent(el.RoomID, EvalsMessage, EvalsMessageContent{ + Version: Version, + Evals: evals, + JoinID: el.LocalVoter.JoinEvt.ID, + StartID: *el.StartID, + }) if err != nil { el.LocalVoter.Eval = nil return err @@ -211,33 +200,39 @@ func (el *Election) SendEvals(client *mautrix.Client) error { func (el *Election) SendSum(client *mautrix.Client) error { sum := big.NewInt(0) + var evalsIDs []id.EventID var wg sync.WaitGroup - for _, voterJoinId := range *el.FinalVoters { + for _, voterJoinId := range *el.FinalJoinIDs { wg.Add(1) go func(voter *Voter) { - for voter.Eval == nil { + for voter.EvalsID == nil { time.Sleep(time.Millisecond * 100) } + evalsIDs = append(evalsIDs, *voter.EvalsID) sum.Add(sum, voter.Eval) wg.Done() }(el.Joins[voterJoinId]) } wg.Wait() - _, err := client.SendMessageEvent(el.CreateEvt.RoomID, SumMessage, SumMessageContent{ - JoinEventId: el.LocalVoter.JoinEvt.ID, - Sum: base64.StdEncoding.EncodeToString(sum.Bytes()), + _, err := client.SendMessageEvent(el.RoomID, SumMessage, SumMessageContent{ + Version: Version, + EvalsIDs: evalsIDs, + JoinID: el.LocalVoter.JoinEvt.ID, + Sum: base64.StdEncoding.EncodeToString(sum.Bytes()), }) if err != nil { return err } + el.Lock() + defer el.Save() + defer el.Unlock() el.LocalVoter.Sum = sum - el.Save() return nil } func (el *Election) GetSums(client *mautrix.Client) { var wg sync.WaitGroup - for _, voterJoinId := range *el.FinalVoters { + for _, voterJoinId := range *el.FinalJoinIDs { wg.Add(1) go func(voter *Voter) { for voter.Sum == nil { @@ -262,16 +257,16 @@ func (el *Election) GetSums(client *mautrix.Client) { } func constructPolyMatrix(el *Election) math.Matrix { - mat := make(math.Matrix, len(el.Joins)) + mat := make(math.Matrix, len(*el.FinalJoinIDs)) i := 0 - for _, voterJoinId := range *el.FinalVoters { + for _, voterJoinId := range *el.FinalJoinIDs { voter := el.Joins[voterJoinId] mat[i] = make([]big.Rat, len(mat)+1) // includes column for sum row := mat[i] row[0].SetInt64(1) var j int64 - for j = 1; j <= int64(len(*el.FinalVoters)-1); j++ { + for j = 1; j <= int64(len(*el.FinalJoinIDs)-1); j++ { row[j].SetInt(new(big.Int).Exp(&voter.Input, big.NewInt(j), nil)) } row[j].SetInt(voter.Sum) diff --git a/go.mod b/go.mod index 052c7e4..475ea1c 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/rivo/tview v0.0.0-20200528200248-fe953220389f github.com/sirupsen/logrus v1.2.0 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/mod v0.4.1 golang.org/x/net v0.0.0-20201110031124-69a78807bb2b // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect maunium.net/go/mautrix v0.7.13 diff --git a/go.sum b/go.sum index 33a2439..342d632 100644 --- a/go.sum +++ b/go.sum @@ -58,16 +58,21 @@ github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U= github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.4.1 h1:Kvvh58BN8Y9/lBi7hTekvtMpm07eUZ0ck5pRHpsMWrY= +golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM= golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -84,6 +89,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/ui/tui.go b/ui/tui.go index c7327a7..5b726d6 100644 --- a/ui/tui.go +++ b/ui/tui.go @@ -105,7 +105,8 @@ func RoomTUI(client *mautrix.Client, roomID id.RoomID, elections *election.Elect if title == "" { title = "" } - list.AddItem(title, fmt.Sprintf("created by %s, ID: %s", el.CreateEvt.Sender, el.CreateEvt.ID), 0, nil) + list.AddItem(title, fmt.Sprintf("created by %s, ID: %s", + el.CreateEvt.Sender, el.CreateEvt.ID), 0, nil) } } go func() { @@ -131,7 +132,7 @@ func RoomTUI(client *mautrix.Client, roomID id.RoomID, elections *election.Elect if el.LocalVoter == nil { // ask user if s/he wants to join election if joinElectionConfirmation(el) { - err = el.JoinElection(client) + err = el.JoinElection(client, elections.EventStore) } else { // user needs to select a different election el, err = RoomTUI(client, roomID, elections) @@ -147,11 +148,11 @@ func RoomTUI(client *mautrix.Client, roomID id.RoomID, elections *election.Elect } log.Debugf("created election title: %s", title) log.Debugf("created election candidates: %s", candidates) - el, err = election.CreateElection(client, candidates, title, roomID, elections) + el, err = elections.CreateElection(client, candidates, title, roomID) if err != nil { return } - err = el.JoinElection(client) + err = el.JoinElection(client, elections.EventStore) if err != nil { return } @@ -169,7 +170,7 @@ func joinElectionConfirmation(el *election.Election) (shouldJoin bool) { el.RLock() // TODO: handle when election starts while in modal - if el.StartEvt != nil { + if el.StartID != nil { buttons = []string{"Ok"} text = "Election has already started, sorry" } else { @@ -243,13 +244,13 @@ func CreateElectionTUI() (title string, candidates []election.Candidate) { return title, candidates } -func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error { +func ElectionWaitTUI(client *mautrix.Client, el *election.Election, eventStore *election.EventStore) error { votersTextView := tview.NewTextView() frame := tview.NewFrame(votersTextView) frame.SetTitle(el.Title).SetBorder(true) app := newTallyardApplication() el.RLock() - if el.CreateEvt.Sender == el.LocalVoter.UserID { + if el.CreateEvt.Sender == el.LocalVoter.JoinEvt.Sender { frame.AddText("Press enter to start the election", false, tview.AlignCenter, tcell.ColorWhite) app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEnter { @@ -257,7 +258,7 @@ func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error { // frame.Clear() // frame.AddText("Starting election...", false, tview.AlignCenter, tcell.ColorWhite) // }) - err := el.StartElection(client) + err := el.StartElection(client, eventStore) if err != nil { panic(err) } @@ -273,7 +274,7 @@ func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error { // TODO: handle duplicate joins from one UserID voters := make([]string, 0, len(el.Joins)) for _, voter := range el.Joins { - voters = append(voters, voter.UserID.String()) + voters = append(voters, voter.JoinEvt.Sender.String()) } el.RUnlock() sort.Strings(voters) @@ -287,7 +288,7 @@ func ElectionWaitTUI(client *mautrix.Client, el *election.Election) error { // has the election started? el.RLock() - started := el.StartEvt != nil + started := el.StartID != nil el.RUnlock() if started { app.Stop() -- 2.38.4