package election import ( "encoding/json" "fmt" "sync" log "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // Used to store all relevant election Events that we've tried processing. A // value of nil for a given event ID key in the Events map means that the // corresponding event was processed unsuccessfully. Likewise, a non-nil value // indicates that the event was processed successfully. // // This EventStore also uses the provided Client to fetch missing events and // pipe them through the corresponding EventHandlers when needed. type EventStore struct { sync.RWMutex Events map[id.EventID]*event.Event `json:"events"` Client *mautrix.Client `json:"-"` EventHandlers map[event.Type]func(*event.Event)bool `json:"-"` Processing map[id.EventID]<-chan struct{} `json:"-"` } func NewEventStore(client *mautrix.Client, eventHandlers map[event.Type]func(*event.Event)bool) *EventStore { return &EventStore{ Events: make(map[id.EventID]*event.Event), Client: client, EventHandlers: eventHandlers, Processing: make(map[id.EventID]<-chan struct{}), } } func (store *EventStore) UnmarshalJSON(b []byte) error { type Alias EventStore err := json.Unmarshal(b, (*Alias)(store)) if err != nil { return err } for _, event := range store.Events { if event == nil { continue } // event.Type.Class is "UnkownEventType" after unmarshlling switch event.Type.Type { case CreateElectionMessage.Type: event.Type = CreateElectionMessage case JoinElectionMessage.Type: event.Type = JoinElectionMessage case StartElectionMessage.Type: event.Type = StartElectionMessage case KeysMessage.Type: event.Type = KeysMessage case EvalsMessage.Type: event.Type = EvalsMessage case SumMessage.Type: event.Type = SumMessage } if err = event.Content.ParseRaw(event.Type); err != nil { return err } } return nil } type CreateEvent struct { *event.Event *CreateElectionContent } type JoinEvent struct { *event.Event *JoinElectionContent } type StartEvent struct { *event.Event *StartElectionContent } type KeysEvent struct { *event.Event *KeysMessageContent } type EvalsEvent struct { *event.Event *EvalsMessageContent } type SumEvent struct { *event.Event *SumMessageContent } func (store *EventStore) GetCreateEvent(roomID id.RoomID, createID id.EventID) *CreateEvent { evt, err := store.getAndHandleEvent(roomID, createID, CreateElectionMessage) if err != nil { log.Warnf("an error occurred getting create event '%s': %s", createID, err) return nil } if evt == nil { return nil } return &CreateEvent{ evt, evt.Content.Parsed.(*CreateElectionContent), } } func (store *EventStore) GetJoinEvent(roomID id.RoomID, joinID id.EventID) *JoinEvent { evt, err := store.getAndHandleEvent(roomID, joinID, JoinElectionMessage) if err != nil { log.Warnf("an error occurred getting join event '%s': %s", joinID, err) return nil } if evt == nil { return nil } return &JoinEvent{ evt, evt.Content.Parsed.(*JoinElectionContent), } } func (store *EventStore) GetStartEvent(roomID id.RoomID, startID id.EventID) *StartEvent { evt, err := store.getAndHandleEvent(roomID, startID, StartElectionMessage) if err != nil { log.Warnf("an error occurred getting start event '%s': %s", startID, err) return nil } if evt == nil { return nil } return &StartEvent{ evt, evt.Content.Parsed.(*StartElectionContent), } } func (store *EventStore) GetKeysEvent(roomID id.RoomID, keysID id.EventID) *KeysEvent { evt, err := store.getAndHandleEvent(roomID, keysID, KeysMessage) if err != nil { log.Warnf("an error occurred getting keys event '%s': %s", keysID, err) return nil } if evt == nil { return nil } return &KeysEvent{ evt, evt.Content.Parsed.(*KeysMessageContent), } } func (store *EventStore) GetEvalsEvent(roomID id.RoomID, evalsID id.EventID) *EvalsEvent { evt, err := store.getAndHandleEvent(roomID, evalsID, EvalsMessage) if err != nil { log.Warnf("an error occurred getting evals event '%s': %s", evalsID, err) return nil } if evt == nil { return nil } return &EvalsEvent{ evt, evt.Content.Parsed.(*EvalsMessageContent), } } func (store *EventStore) GetSumEvent(roomID id.RoomID, sumID id.EventID) *SumEvent { evt, err := store.getAndHandleEvent(roomID, sumID, SumMessage) if err != nil { log.Warnf("an error occurred getting sum event '%s': %s", sumID, err) return nil } if evt == nil { return nil } return &SumEvent{ evt, evt.Content.Parsed.(*SumMessageContent), } } func (store *EventStore) getAndHandleEvent(roomID id.RoomID, eventID id.EventID, eventType event.Type) (*event.Event, error) { // see if we've handled this event before store.RLock() evt, wasProcessed := store.Events[eventID] done, isProcessing := store.Processing[eventID] store.RUnlock() if wasProcessed { // This event was seen before. If evt != nil, it was handled // successfully and vice versa. return evt, nil } if isProcessing { // wait for result from handler and skip fetch log.Debugf("event %s is processing; waiting for processing to finish", eventID) <-done store.RLock() defer store.RUnlock() evt, wasProcessed = store.Events[eventID] if !wasProcessed { return nil, fmt.Errorf("event %s isn't in events map despite seeming to be processed", eventID) } if evt == nil { return nil, fmt.Errorf("event %s failed processing while we were waiting for it", eventID) } return evt, nil } // we've never seen this event before evt, err := store.fetchEvent(roomID, eventID) if err != nil { return nil, fmt.Errorf("couldn't fetch %s event '%s': %s", eventType.Type, eventID, err) } if evt == nil { return nil, fmt.Errorf("fetched %s event '%s' is nil", eventType.Type, eventID) } // TODO check version here rather than in msg? err = evt.Content.ParseRaw(eventType) if err != nil { return nil, fmt.Errorf("couldn't parse %s event '%s' content: %s", eventType.Type, eventID, err) } if !store.EventHandlers[eventType](evt) { return nil, fmt.Errorf("couldn't handle %s event '%s'", eventType.Type, eventID) } return evt, nil } func (store *EventStore) fetchEvent(roomID id.RoomID, eventID id.EventID) (*event.Event, error) { evt, err := store.Client.GetEvent(roomID, eventID) if err != nil { return nil, err } if evt.Unsigned.RedactedBecause != nil { return nil, fmt.Errorf("event %s was redacted", eventID) } return evt, nil }