package election
import (
"encoding/json"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// EventStore records every election event we've tried to process. A nil value
// for an event ID in Events means the event was processed unsuccessfully;
// a non-nil value is the successfully processed event itself.
//
// EventStore also uses Client to fetch events it hasn't seen yet and pipes
// them through the matching EventHandlers entry when needed. The embedded
// RWMutex is read-locked by getAndHandleEvent while it reads Events and
// Processing.
type EventStore struct {
	sync.RWMutex
	// Events maps event IDs to their processing outcome (see above). It is
	// the only field included in the JSON encoding.
	Events map[id.EventID]*event.Event `json:"events"`
	// Client is used to fetch events that are missing from Events.
	Client *mautrix.Client `json:"-"`
	// EventHandlers maps an event type to the function that handles events
	// of that type; the function returns false if handling failed.
	EventHandlers map[event.Type]func(*event.Event) bool `json:"-"`
	// Processing holds a channel for each event that is currently being
	// processed. getAndHandleEvent waits on that channel and then reads the
	// outcome from Events.
	// NOTE(review): presumably whoever registers the channel closes it once
	// the outcome has been stored in Events — confirm against the handlers.
	Processing map[id.EventID]<-chan struct{} `json:"-"`
}
// NewEventStore returns an empty EventStore that fetches unseen events with
// client and hands each one to the eventHandlers entry for its type.
func NewEventStore(client *mautrix.Client, eventHandlers map[event.Type]func(*event.Event) bool) *EventStore {
	s := new(EventStore)
	s.Client = client
	s.EventHandlers = eventHandlers
	s.Events = make(map[id.EventID]*event.Event)
	s.Processing = make(map[id.EventID]<-chan struct{})
	return s
}
// UnmarshalJSON decodes an EventStore from JSON. Only the "events" field is
// encoded; Client, EventHandlers and Processing are tagged `json:"-"` and keep
// whatever values the receiver already had.
//
// After decoding, every non-nil stored event has its full Type restored for
// the known election message types, and its content is parsed again so that
// Content.Parsed is populated. It returns an error if decoding fails or if any
// stored event's content cannot be parsed.
func (store *EventStore) UnmarshalJSON(b []byte) error {
	// Alias has EventStore's fields but none of its methods, so decoding into
	// it does not recurse back into this UnmarshalJSON.
	type Alias EventStore
	if err := json.Unmarshal(b, (*Alias)(store)); err != nil {
		return err
	}

	// A store decoded into a zero value (instead of one built by
	// NewEventStore), or from `"events": null`, would otherwise be left with
	// nil maps, and writing to a nil map panics.
	if store.Events == nil {
		store.Events = make(map[id.EventID]*event.Event)
	}
	if store.Processing == nil {
		store.Processing = make(map[id.EventID]<-chan struct{})
	}

	for evtID, evt := range store.Events {
		if evt == nil {
			// nil records a failed event; there is no content to parse.
			continue
		}
		// evt.Type.Class is UnknownEventType after unmarshalling, so restore
		// the full Type (including Class) of the election types we know about
		// before parsing the content again.
		switch evt.Type.Type {
		case CreateElectionMessage.Type:
			evt.Type = CreateElectionMessage
		case JoinElectionMessage.Type:
			evt.Type = JoinElectionMessage
		case StartElectionMessage.Type:
			evt.Type = StartElectionMessage
		case KeysMessage.Type:
			evt.Type = KeysMessage
		case EvalsMessage.Type:
			evt.Type = EvalsMessage
		case SumMessage.Type:
			evt.Type = SumMessage
		}
		if err := evt.Content.ParseRaw(evt.Type); err != nil {
			return fmt.Errorf("couldn't parse content of stored event %s: %w", evtID, err)
		}
	}
	return nil
}
// CreateEvent pairs an event with its content parsed as
// *CreateElectionContent, as returned by GetCreateEvent.
type CreateEvent struct {
	*event.Event
	*CreateElectionContent
}

// JoinEvent pairs an event with its content parsed as *JoinElectionContent,
// as returned by GetJoinEvent.
type JoinEvent struct {
	*event.Event
	*JoinElectionContent
}

// StartEvent pairs an event with its content parsed as
// *StartElectionContent, as returned by GetStartEvent.
type StartEvent struct {
	*event.Event
	*StartElectionContent
}

// KeysEvent pairs an event with its content parsed as *KeysMessageContent,
// as returned by GetKeysEvent.
type KeysEvent struct {
	*event.Event
	*KeysMessageContent
}

// EvalsEvent pairs an event with its content parsed as *EvalsMessageContent,
// as returned by GetEvalsEvent.
type EvalsEvent struct {
	*event.Event
	*EvalsMessageContent
}

// SumEvent pairs an event with its content parsed as *SumMessageContent,
// as returned by GetSumEvent.
type SumEvent struct {
	*event.Event
	*SumMessageContent
}
// GetCreateEvent returns event createID from roomID as a CreateEvent,
// fetching and handling it first if the store has not seen it yet.
//
// It returns nil if the event could not be fetched, parsed or handled, if it
// previously failed processing, or if its parsed content is not a
// *CreateElectionContent. Errors are logged rather than returned.
func (store *EventStore) GetCreateEvent(roomID id.RoomID, createID id.EventID) *CreateEvent {
	evt, err := store.getAndHandleEvent(roomID, createID, CreateElectionMessage)
	if err != nil {
		log.Warnf("an error occurred getting create event '%s': %s", createID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	// A cached event may have been parsed as another type (e.g. the ID
	// referenced a different kind of election event); don't let that panic.
	content, ok := evt.Content.Parsed.(*CreateElectionContent)
	if !ok {
		log.Warnf("event '%s' is not a create event (content is %T)", createID, evt.Content.Parsed)
		return nil
	}
	return &CreateEvent{evt, content}
}
// GetJoinEvent returns event joinID from roomID as a JoinEvent, fetching and
// handling it first if the store has not seen it yet.
//
// It returns nil if the event could not be fetched, parsed or handled, if it
// previously failed processing, or if its parsed content is not a
// *JoinElectionContent. Errors are logged rather than returned.
func (store *EventStore) GetJoinEvent(roomID id.RoomID, joinID id.EventID) *JoinEvent {
	evt, err := store.getAndHandleEvent(roomID, joinID, JoinElectionMessage)
	if err != nil {
		log.Warnf("an error occurred getting join event '%s': %s", joinID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	// A cached event may have been parsed as another type (e.g. the ID
	// referenced a different kind of election event); don't let that panic.
	content, ok := evt.Content.Parsed.(*JoinElectionContent)
	if !ok {
		log.Warnf("event '%s' is not a join event (content is %T)", joinID, evt.Content.Parsed)
		return nil
	}
	return &JoinEvent{evt, content}
}
// GetStartEvent returns event startID from roomID as a StartEvent, fetching
// and handling it first if the store has not seen it yet.
//
// It returns nil if the event could not be fetched, parsed or handled, if it
// previously failed processing, or if its parsed content is not a
// *StartElectionContent. Errors are logged rather than returned.
func (store *EventStore) GetStartEvent(roomID id.RoomID, startID id.EventID) *StartEvent {
	evt, err := store.getAndHandleEvent(roomID, startID, StartElectionMessage)
	if err != nil {
		log.Warnf("an error occurred getting start event '%s': %s", startID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	// A cached event may have been parsed as another type (e.g. the ID
	// referenced a different kind of election event); don't let that panic.
	content, ok := evt.Content.Parsed.(*StartElectionContent)
	if !ok {
		log.Warnf("event '%s' is not a start event (content is %T)", startID, evt.Content.Parsed)
		return nil
	}
	return &StartEvent{evt, content}
}
// GetKeysEvent returns event keysID from roomID as a KeysEvent, fetching and
// handling it first if the store has not seen it yet.
//
// It returns nil if the event could not be fetched, parsed or handled, if it
// previously failed processing, or if its parsed content is not a
// *KeysMessageContent. Errors are logged rather than returned.
func (store *EventStore) GetKeysEvent(roomID id.RoomID, keysID id.EventID) *KeysEvent {
	evt, err := store.getAndHandleEvent(roomID, keysID, KeysMessage)
	if err != nil {
		log.Warnf("an error occurred getting keys event '%s': %s", keysID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	// A cached event may have been parsed as another type (e.g. the ID
	// referenced a different kind of election event); don't let that panic.
	content, ok := evt.Content.Parsed.(*KeysMessageContent)
	if !ok {
		log.Warnf("event '%s' is not a keys event (content is %T)", keysID, evt.Content.Parsed)
		return nil
	}
	return &KeysEvent{evt, content}
}
// GetEvalsEvent returns event evalsID from roomID as an EvalsEvent, fetching
// and handling it first if the store has not seen it yet.
//
// It returns nil if the event could not be fetched, parsed or handled, if it
// previously failed processing, or if its parsed content is not an
// *EvalsMessageContent. Errors are logged rather than returned.
func (store *EventStore) GetEvalsEvent(roomID id.RoomID, evalsID id.EventID) *EvalsEvent {
	evt, err := store.getAndHandleEvent(roomID, evalsID, EvalsMessage)
	if err != nil {
		log.Warnf("an error occurred getting evals event '%s': %s", evalsID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	// A cached event may have been parsed as another type (e.g. the ID
	// referenced a different kind of election event); don't let that panic.
	content, ok := evt.Content.Parsed.(*EvalsMessageContent)
	if !ok {
		log.Warnf("event '%s' is not an evals event (content is %T)", evalsID, evt.Content.Parsed)
		return nil
	}
	return &EvalsEvent{evt, content}
}
// GetSumEvent returns event sumID from roomID as a SumEvent, fetching and
// handling it first if the store has not seen it yet.
//
// It returns nil if the event could not be fetched, parsed or handled, if it
// previously failed processing, or if its parsed content is not a
// *SumMessageContent. Errors are logged rather than returned.
func (store *EventStore) GetSumEvent(roomID id.RoomID, sumID id.EventID) *SumEvent {
	evt, err := store.getAndHandleEvent(roomID, sumID, SumMessage)
	if err != nil {
		log.Warnf("an error occurred getting sum event '%s': %s", sumID, err)
		return nil
	}
	if evt == nil {
		return nil
	}
	// A cached event may have been parsed as another type (e.g. the ID
	// referenced a different kind of election event); don't let that panic.
	content, ok := evt.Content.Parsed.(*SumMessageContent)
	if !ok {
		log.Warnf("event '%s' is not a sum event (content is %T)", sumID, evt.Content.Parsed)
		return nil
	}
	return &SumEvent{evt, content}
}
// getAndHandleEvent returns event eventID from roomID, which the caller
// expects to be of type eventType.
//
// If Events already holds an outcome for the event, that outcome is returned
// without contacting the server: the event if it was handled successfully,
// or (nil, nil) if it failed processing. If the event has an entry in
// Processing, this call blocks until that channel yields and then reports the
// outcome recorded in Events. Otherwise the event is fetched with the store's
// Client, its content is parsed as eventType, and it is passed to the
// EventHandlers entry for eventType.
//
// An error is returned if any of those steps fail, if no handler is
// registered for eventType, or if the event's type string differs from
// eventType's.
//
// This function does not write to Events or Processing itself.
// NOTE(review): presumably the handlers record outcomes there — confirm.
func (store *EventStore) getAndHandleEvent(roomID id.RoomID, eventID id.EventID, eventType event.Type) (*event.Event, error) {
	// An event ID may come from a reference in someone else's message, so it
	// can point at an event of a different type than the caller asked for.
	// Reject that here instead of letting callers' type assertions on
	// Content.Parsed panic. A nil event (failed processing) passes through.
	checkType := func(got *event.Event) error {
		if got != nil && got.Type.Type != eventType.Type {
			return fmt.Errorf("event %s has type %s, expected %s", eventID, got.Type.Type, eventType.Type)
		}
		return nil
	}

	// see if we've handled this event before
	store.RLock()
	evt, wasProcessed := store.Events[eventID]
	done, isProcessing := store.Processing[eventID]
	store.RUnlock()

	if wasProcessed {
		// This event was seen before. If evt != nil, it was handled
		// successfully; a nil evt means it failed processing earlier.
		if err := checkType(evt); err != nil {
			return nil, err
		}
		return evt, nil
	}

	if isProcessing {
		// wait for the result from the in-flight handler and skip the fetch
		log.Debugf("event %s is processing; waiting for processing to finish", eventID)
		<-done
		store.RLock()
		evt, wasProcessed = store.Events[eventID]
		store.RUnlock()
		if !wasProcessed {
			return nil, fmt.Errorf("event %s isn't in events map despite seeming to be processed", eventID)
		}
		if evt == nil {
			return nil, fmt.Errorf("event %s failed processing while we were waiting for it", eventID)
		}
		if err := checkType(evt); err != nil {
			return nil, err
		}
		return evt, nil
	}

	// We've never seen this event before. Resolve the handler first, so a
	// missing registration is reported as an error instead of a nil-func
	// panic, and without a wasted round trip to the server.
	handler := store.EventHandlers[eventType]
	if handler == nil {
		return nil, fmt.Errorf("no handler registered for %s events", eventType.Type)
	}

	evt, err := store.fetchEvent(roomID, eventID)
	if err != nil {
		return nil, fmt.Errorf("couldn't fetch %s event '%s': %w", eventType.Type, eventID, err)
	}
	if evt == nil {
		return nil, fmt.Errorf("fetched %s event '%s' is nil", eventType.Type, eventID)
	}
	if err := checkType(evt); err != nil {
		return nil, err
	}
	// TODO check version here rather than in msg?
	if err := evt.Content.ParseRaw(eventType); err != nil {
		return nil, fmt.Errorf("couldn't parse %s event '%s' content: %w", eventType.Type, eventID, err)
	}
	if !handler(evt) {
		return nil, fmt.Errorf("couldn't handle %s event '%s'", eventType.Type, eventID)
	}
	return evt, nil
}
// fetchEvent retrieves event eventID from roomID using the store's Client.
//
// It returns an error if the request fails, if no event comes back, or if
// the event has been redacted (Unsigned.RedactedBecause is set).
func (store *EventStore) fetchEvent(roomID id.RoomID, eventID id.EventID) (*event.Event, error) {
	evt, err := store.Client.GetEvent(roomID, eventID)
	if err != nil {
		return nil, err
	}
	// NOTE(review): GetEvent may yield a nil event with a nil error (e.g. a
	// JSON `null` body) — callers already check for this. Guard it here so
	// the redaction check below doesn't dereference nil.
	if evt == nil {
		return nil, fmt.Errorf("server returned no event for %s", eventID)
	}
	if evt.Unsigned.RedactedBecause != nil {
		return nil, fmt.Errorf("event %s was redacted", eventID)
	}
	return evt, nil
}