~edwargix/msc-link-bot

432e291f684d8b3475794d5984a4d23574784912 — David Florness 3 years ago 38e1ca0
Implement ze bot
6 files changed, 451 insertions(+), 0 deletions(-)

A .gitignore
A crypto_logger.go
M go.mod
A go.sum
A main.go
A store.go
A .gitignore => .gitignore +1 -0
@@ 0,0 1,1 @@
msc-link-bot

A crypto_logger.go => crypto_logger.go +21 -0
@@ 0,0 1,21 @@
package main

import log "github.com/sirupsen/logrus"

type cryptoLogger struct{}

func (f cryptoLogger) Error(message string, args ...interface{}) {
	log.Errorf(message, args...)
}

func (f cryptoLogger) Warn(message string, args ...interface{}) {
	log.Warnf(message, args...)
}

func (f cryptoLogger) Debug(message string, args ...interface{}) {
	log.Debugf(message, args...)
}

func (f cryptoLogger) Trace(message string, args ...interface{}) {
	log.Tracef(message, args...)
}

M go.mod => go.mod +22 -0
@@ 1,3 1,25 @@
module git.hnitbjorg.xyz/~edwargix/msc-link-bot

go 1.17

require (
	github.com/mattn/go-sqlite3 v1.14.9
	github.com/sirupsen/logrus v1.8.1
	maunium.net/go/mautrix v0.9.31
)

require (
	github.com/btcsuite/btcutil v1.0.2 // indirect
	github.com/gorilla/mux v1.8.0 // indirect
	github.com/gorilla/websocket v1.4.2 // indirect
	github.com/lib/pq v1.10.3 // indirect
	github.com/tidwall/gjson v1.10.2 // indirect
	github.com/tidwall/match v1.1.1 // indirect
	github.com/tidwall/pretty v1.2.0 // indirect
	github.com/tidwall/sjson v1.2.3 // indirect
	golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect
	golang.org/x/net v0.0.0-20211020060615-d418f374d309 // indirect
	golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect
	gopkg.in/yaml.v2 v2.4.0 // indirect
	maunium.net/go/maulogger/v2 v2.3.1 // indirect
)

A go.sum => go.sum +86 -0
@@ 0,0 1,86 @@
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ=
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA=
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg=
github.com/btcsuite/btcutil v1.0.2 h1:9iZ1Terx9fMIOtq1VrwdqfsATL9MC2l8ZrUY6YZ2uts=
github.com/btcsuite/btcutil v1.0.2/go.mod h1:j9HUFwoQRsZL3V4n+qG+CUnEGHOarIxfC3Le2Yhbcts=
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg=
github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY=
github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc=
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY=
github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs=
github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ=
github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4=
github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg=
github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tidwall/gjson v1.10.2 h1:APbLGOM0rrEkd8WBw9C24nllro4ajFuJu0Sc9hRz8Bo=
github.com/tidwall/gjson v1.10.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.3 h1:5+deguEhHSEjmuICXZ21uSSsXotWMA0orU783+Z7Cp8=
github.com/tidwall/sjson v1.2.3/go.mod h1:5WdjKx3AQMvCJ4RG6/2UYT7dLrGvJUV1x4jdTAyGvZs=
golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211020060615-d418f374d309 h1:A0lJIi+hcTR6aajJH4YqKWwohY4aW9RO7oRMcdv+HKI=
golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/maulogger/v2 v2.3.1 h1:fwBYJne0pHvJrrIPHK+TAPfyxxbBEz46oVGez2x0ODE=
maunium.net/go/maulogger/v2 v2.3.1/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
maunium.net/go/mautrix v0.9.31 h1:n7UF5tqq2zCyfdNsv++RyQ2anjjrFVOmOA2VkZCSgZc=
maunium.net/go/mautrix v0.9.31/go.mod h1:3U7pOAx4bxdIVJuunLDAToI+M7YwxcGMm74zBmX5aY0=

A main.go => main.go +192 -0
@@ 0,0 1,192 @@
package main

import (
	"database/sql"
	"fmt"
	"os"
	"regexp"

	_ "github.com/mattn/go-sqlite3"

	log "github.com/sirupsen/logrus"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/crypto"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

var mscRegex *regexp.Regexp

func main() {
	store := NewMSCBotStore()
	client := mkClient(store)

	cryptoDB, err := sql.Open("sqlite3", "crypto.db")
	if err != nil {
		log.Fatalf("couldn't open crypto db: %v", err)
	}
	defer cryptoDB.Close()

	cryptoLogger := cryptoLogger{}
	cryptoStore := crypto.NewSQLCryptoStore(
		cryptoDB,
		"sqlite3",
		fmt.Sprintf("%s/%s", client.UserID, client.DeviceID),
		client.DeviceID,
		[]byte("xyz.hnitbjorg.msc_link_bot"),
		cryptoLogger,
	)
	err = cryptoStore.CreateTables()
	if err != nil {
		log.Fatalf("couldn't create crypto store tables: %v", err)
	}

	olmMachine := crypto.NewOlmMachine(client, cryptoLogger, cryptoStore, store)
	err = olmMachine.Load()
	if err != nil {
		log.Fatalf("couldn't load olm machine: %v", err)
	}

	mscRegex, err = regexp.Compile("\\b(?:MSC|msc)(\\d+)\\b")
	if err != nil {
		// should never happen
		log.Fatalf("couldn't compile regex: %v", err)
	}

	syncer := client.Syncer.(*mautrix.DefaultSyncer)
	syncer.OnSync(olmMachine.ProcessSyncResponse)
	syncer.OnEventType(event.StateMember, func(_ mautrix.EventSource, evt *event.Event) {
		olmMachine.HandleMemberEvent(evt)
	})
	syncer.OnEvent(store.UpdateState)
	syncer.OnEventType(event.EventMessage, func(_ mautrix.EventSource, evt *event.Event) {
		ret := getMsgResponse(client, evt)
		if ret == "" {
			return
		}
		resp, err := client.SendMessageEvent(evt.RoomID, event.EventMessage, event.MessageEventContent{
			MsgType: event.MsgText,
			Body:    ret,
		})
		if err != nil {
			log.Errorf("couldn't send event: %v", err)
			return
		}
		log.Infof("sent event %v", resp.EventID)
	})
	syncer.OnEventType(event.EventEncrypted, func(_ mautrix.EventSource, encEvt *event.Event) {
		evt, err := olmMachine.DecryptMegolmEvent(encEvt)
		if err != nil {
			log.Errorf("couldn't decrypt event %v: %v", encEvt.ID, err)
			return
		}
		if evt.Type != event.EventMessage {
			return
		}
		ret := getMsgResponse(client, evt)
		if ret == "" {
			return
		}
		content := event.MessageEventContent{
			MsgType: event.MsgText,
			Body:    ret,
		}
		encrypted, err := olmMachine.EncryptMegolmEvent(evt.RoomID, evt.Type, content)
		if err != nil {
			if isBadEncryptError(err) {
				log.Errorf("couldn't encrypt event: %v", err)
				return
			}
			log.Debugf("got %s while trying to encrypt message; sharing group session and trying again...", err)
			err = olmMachine.ShareGroupSession(evt.RoomID, store.GetRoomMembers(evt.RoomID))
			if err != nil {
				log.Errorf("couldn't share group session: %v", err)
				return
			}
			encrypted, err = olmMachine.EncryptMegolmEvent(evt.RoomID, evt.Type, content)
			if err != nil {
				log.Errorf("couldn't encrypt event(2): %v", err)
				return
			}
		}
		resp, err := client.SendMessageEvent(evt.RoomID, event.EventEncrypted, encrypted)
		if err != nil {
			log.Errorf("couldn't send encrypted event: %v", err)
			return
		}
		log.Infof("sent encrypted event %v", resp.EventID)
	})
	syncer.OnEvent(func (_ mautrix.EventSource, evt *event.Event) {
		err := olmMachine.FlushStore()
		if err != nil {
			panic(err)
		}
	})

	err = client.Sync()
	if err != nil {
		log.Fatalf("error syncing: %v", err)
	}
}

func isBadEncryptError(err error) bool {
	return err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession
}

// this function assumes evt.Type is EventMessage
// return value is the body of the message to send back, if any
func getMsgResponse(client *mautrix.Client, evt *event.Event) string {
	content := evt.Content.AsMessage()
	if content.MsgType != event.MsgText {
		return ""
	}
	mscs := getMSCs(content.Body)
	ret := ""
	for i, msc := range mscs {
		log.Infof("MSC: %v %v\n", evt.ID, msc)
		if i > 0 {
			ret += "\n"
		}
		ret += fmt.Sprintf("https://github.com/matrix-org/matrix-doc/pull/%v", msc)
	}
	return ret
}

func getMSCs(body string) (mscs []string) {
	matches := mscRegex.FindAllStringSubmatch(body, -1)
	for _, match := range matches {
		mscs = append(mscs, match[1])
	}
	return mscs
}

func mkClient(store mautrix.Storer) *mautrix.Client {
	homeserver := os.Getenv("HOMESERVER")
	if homeserver == "" {
		log.Fatal("required envvar HOMESERVER not set")
	}

	userID := os.Getenv("USER_ID")
	if userID == "" {
		log.Fatal("required envvar USER_ID not set")
	}

	deviceID := os.Getenv("DEVICE_ID")
	if deviceID == "" {
		log.Fatal("required envvar DEVICE_ID not set")
	}

	accessToken := os.Getenv("ACCESS_TOKEN")
	if accessToken == "" {
		log.Fatal("required envvar ACCESS_TOKEN not set")
	}

	client, err := mautrix.NewClient(homeserver, id.UserID(userID), accessToken)
	if err != nil {
		log.Fatalf("couldn't create client: %v", err)
	}
	client.DeviceID = id.DeviceID(deviceID)
	client.Store = store

	return client
}

A store.go => store.go +129 -0
@@ 0,0 1,129 @@
package main

import (
	"sync"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

type MSCBotStore struct {
	sync.RWMutex
	FilterIDs map[id.UserID]string
	NextBatches map[id.UserID]string
	Rooms map[id.RoomID]*mautrix.Room
}

func NewMSCBotStore() *MSCBotStore {
	return &MSCBotStore{
		FilterIDs: make(map[id.UserID]string),
		NextBatches: make(map[id.UserID]string),
		Rooms: make(map[id.RoomID]*mautrix.Room),
	}
}

// mautrix.Storer interface implemented below

func (s *MSCBotStore) SaveFilterID(userID id.UserID, filterID string) {
	s.Lock()
	defer s.Unlock()
	s.FilterIDs[userID] = filterID
}

func (s *MSCBotStore) LoadFilterID(userID id.UserID) string {
	s.RLock()
	defer s.RUnlock()
	return s.FilterIDs[userID]
}

func (s *MSCBotStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
	s.Lock()
	defer s.Unlock()
	s.NextBatches[userID] = nextBatchToken
}

func (s *MSCBotStore) LoadNextBatch(userID id.UserID) string {
	s.RLock()
	defer s.RUnlock()
	return s.NextBatches[userID]
}

func (s *MSCBotStore) SaveRoom(room *mautrix.Room) {
	s.Lock()
	defer s.Unlock()
	s.Rooms[room.ID] = room
}

func (s *MSCBotStore) LoadRoom(roomID id.RoomID) *mautrix.Room {
	s.RLock()
	defer s.RUnlock()
	return s.Rooms[roomID]
}

func (s *MSCBotStore) UpdateState(_ mautrix.EventSource, evt *event.Event) {
	if !evt.Type.IsState() {
		return
	}
	room := s.LoadRoom(evt.RoomID)
	if room == nil {
		room = mautrix.NewRoom(evt.RoomID)
		s.SaveRoom(room)
	}
	room.UpdateState(evt)
}

// crypto.StateStore interface implemented below

// IsEncrypted returns whether a room is encrypted.
func (s *MSCBotStore) IsEncrypted(roomID id.RoomID) bool {
	s.RLock()
	defer s.RUnlock()
	if room, exists := s.Rooms[roomID]; exists {
		return room.GetStateEvent(event.StateEncryption, "") != nil
	}
	return false
}

// GetEncryptionEvent returns the encryption event's content for an encrypted room.
func (s *MSCBotStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
	s.RLock()
	defer s.RUnlock()
	room, exists := s.Rooms[roomID]
	if !exists {
		return nil
	}
	evt := room.GetStateEvent(event.StateEncryption, "")
	content, ok := evt.Content.Parsed.(*event.EncryptionEventContent)
	if !ok {
		return nil
	}
	return content
}

// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
func (s *MSCBotStore) FindSharedRooms(userID id.UserID) []id.RoomID {
	s.RLock()
	defer s.RUnlock()
	var sharedRooms []id.RoomID
	for roomID, room := range s.Rooms {
		// if room isn't encrypted, skip
		if room.GetStateEvent(event.StateEncryption, "") == nil {
			continue
		}
		if room.GetMembershipState(userID) == event.MembershipJoin {
			sharedRooms = append(sharedRooms, roomID)
		}
	}
	return sharedRooms
}

func (s *MSCBotStore) GetRoomMembers(roomID id.RoomID) []id.UserID {
	var members []id.UserID
	for userID, evt := range s.Rooms[roomID].State[event.StateMember] {
		if evt.Content.Parsed.(*event.MemberEventContent).Membership.IsInviteOrJoin() {
			members = append(members, id.UserID(userID))
		}
	}
	return members
}