From 432e291f684d8b3475794d5984a4d23574784912 Mon Sep 17 00:00:00 2001 From: David Florness Date: Fri, 29 Oct 2021 18:37:19 -0400 Subject: [PATCH] Implement ze bot --- .gitignore | 1 + crypto_logger.go | 21 ++++++ go.mod | 22 ++++++ go.sum | 86 +++++++++++++++++++++ main.go | 192 +++++++++++++++++++++++++++++++++++++++++++++++ store.go | 129 +++++++++++++++++++++++++++++++ 6 files changed, 451 insertions(+) create mode 100644 .gitignore create mode 100644 crypto_logger.go create mode 100644 go.sum create mode 100644 main.go create mode 100644 store.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c1166e7 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +msc-link-bot diff --git a/crypto_logger.go b/crypto_logger.go new file mode 100644 index 0000000..fe66d24 --- /dev/null +++ b/crypto_logger.go @@ -0,0 +1,21 @@ +package main + +import log "github.com/sirupsen/logrus" + +type cryptoLogger struct{} + +func (f cryptoLogger) Error(message string, args ...interface{}) { + log.Errorf(message, args...) +} + +func (f cryptoLogger) Warn(message string, args ...interface{}) { + log.Warnf(message, args...) +} + +func (f cryptoLogger) Debug(message string, args ...interface{}) { + log.Debugf(message, args...) +} + +func (f cryptoLogger) Trace(message string, args ...interface{}) { + log.Tracef(message, args...) +} diff --git a/go.mod b/go.mod index af6a208..6c08e60 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,25 @@ module git.hnitbjorg.xyz/~edwargix/msc-link-bot go 1.17 + +require ( + github.com/mattn/go-sqlite3 v1.14.9 + github.com/sirupsen/logrus v1.8.1 + maunium.net/go/mautrix v0.9.31 +) + +require ( + github.com/btcsuite/btcutil v1.0.2 // indirect + github.com/gorilla/mux v1.8.0 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + github.com/lib/pq v1.10.3 // indirect + github.com/tidwall/gjson v1.10.2 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.3 // indirect + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect + golang.org/x/net v0.0.0-20211020060615-d418f374d309 // indirect + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + maunium.net/go/maulogger/v2 v2.3.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e412521 --- /dev/null +++ b/go.sum @@ -0,0 +1,86 @@ +github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= +github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= +github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= +github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= +github.com/btcsuite/btcutil v1.0.2 h1:9iZ1Terx9fMIOtq1VrwdqfsATL9MC2l8ZrUY6YZ2uts= +github.com/btcsuite/btcutil v1.0.2/go.mod h1:j9HUFwoQRsZL3V4n+qG+CUnEGHOarIxfC3Le2Yhbcts= +github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg= +github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY= +github.com/btcsuite/snappy-go v0.0.0-20151229074030-0bdef8d06723/go.mod h1:8woku9dyThutzjeg+3xrA5iCpBRH8XEEg3lh6TiUghc= +github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792/go.mod h1:ghJtEyQwv5/p4Mg4C0fgbePVuGr935/5ddU9Z3TmDRY= +github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46fmI40EZs= +github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= +github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= +github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg= +github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= +github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tidwall/gjson v1.10.2 h1:APbLGOM0rrEkd8WBw9C24nllro4ajFuJu0Sc9hRz8Bo= +github.com/tidwall/gjson v1.10.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.3 h1:5+deguEhHSEjmuICXZ21uSSsXotWMA0orU783+Z7Cp8= +github.com/tidwall/sjson v1.2.3/go.mod h1:5WdjKx3AQMvCJ4RG6/2UYT7dLrGvJUV1x4jdTAyGvZs= +golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211020060615-d418f374d309 h1:A0lJIi+hcTR6aajJH4YqKWwohY4aW9RO7oRMcdv+HKI= +golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +maunium.net/go/maulogger/v2 v2.3.1 h1:fwBYJne0pHvJrrIPHK+TAPfyxxbBEz46oVGez2x0ODE= +maunium.net/go/maulogger/v2 v2.3.1/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= +maunium.net/go/mautrix v0.9.31 h1:n7UF5tqq2zCyfdNsv++RyQ2anjjrFVOmOA2VkZCSgZc= +maunium.net/go/mautrix v0.9.31/go.mod h1:3U7pOAx4bxdIVJuunLDAToI+M7YwxcGMm74zBmX5aY0= diff --git a/main.go b/main.go new file mode 100644 index 0000000..bb5ad8f --- /dev/null +++ b/main.go @@ -0,0 +1,192 @@ +package main + +import ( + "database/sql" + "fmt" + "os" + "regexp" + + _ "github.com/mattn/go-sqlite3" + + log "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var mscRegex *regexp.Regexp + +func main() { + store := NewMSCBotStore() + client := mkClient(store) + + cryptoDB, err := sql.Open("sqlite3", "crypto.db") + if err != nil { + log.Fatalf("couldn't open crypto db: %v", err) + } + defer cryptoDB.Close() + + cryptoLogger := cryptoLogger{} + cryptoStore := crypto.NewSQLCryptoStore( + cryptoDB, + "sqlite3", + fmt.Sprintf("%s/%s", client.UserID, client.DeviceID), + client.DeviceID, + []byte("xyz.hnitbjorg.msc_link_bot"), + cryptoLogger, + ) + err = cryptoStore.CreateTables() + if err != nil { + log.Fatalf("couldn't create crypto store tables: %v", err) + } + + olmMachine := crypto.NewOlmMachine(client, cryptoLogger, cryptoStore, store) + err = olmMachine.Load() + if err != nil { + log.Fatalf("couldn't load olm machine: %v", err) + } + + mscRegex, err = regexp.Compile("\\b(?:MSC|msc)(\\d+)\\b") + if err != nil { + // should never happen + log.Fatalf("couldn't compile regex: %v", err) + } + + syncer := client.Syncer.(*mautrix.DefaultSyncer) + syncer.OnSync(olmMachine.ProcessSyncResponse) + syncer.OnEventType(event.StateMember, func(_ mautrix.EventSource, evt *event.Event) { + olmMachine.HandleMemberEvent(evt) + }) + syncer.OnEvent(store.UpdateState) + syncer.OnEventType(event.EventMessage, func(_ mautrix.EventSource, evt *event.Event) { + ret := getMsgResponse(client, evt) + if ret == "" { + return + } + resp, err := client.SendMessageEvent(evt.RoomID, event.EventMessage, event.MessageEventContent{ + MsgType: event.MsgText, + Body: ret, + }) + if err != nil { + log.Errorf("couldn't send event: %v", err) + return + } + log.Infof("sent event %v", resp.EventID) + }) + syncer.OnEventType(event.EventEncrypted, func(_ mautrix.EventSource, encEvt *event.Event) { + evt, err := olmMachine.DecryptMegolmEvent(encEvt) + if err != nil { + log.Errorf("couldn't decrypt event %v: %v", encEvt.ID, err) + return + } + if evt.Type != event.EventMessage { + return + } + ret := getMsgResponse(client, evt) + if ret == "" { + return + } + content := event.MessageEventContent{ + MsgType: event.MsgText, + Body: ret, + } + encrypted, err := olmMachine.EncryptMegolmEvent(evt.RoomID, evt.Type, content) + if err != nil { + if isBadEncryptError(err) { + log.Errorf("couldn't encrypt event: %v", err) + return + } + log.Debugf("got %s while trying to encrypt message; sharing group session and trying again...", err) + err = olmMachine.ShareGroupSession(evt.RoomID, store.GetRoomMembers(evt.RoomID)) + if err != nil { + log.Errorf("couldn't share group session: %v", err) + return + } + encrypted, err = olmMachine.EncryptMegolmEvent(evt.RoomID, evt.Type, content) + if err != nil { + log.Errorf("couldn't encrypt event(2): %v", err) + return + } + } + resp, err := client.SendMessageEvent(evt.RoomID, event.EventEncrypted, encrypted) + if err != nil { + log.Errorf("couldn't send encrypted event: %v", err) + return + } + log.Infof("sent encrypted event %v", resp.EventID) + }) + syncer.OnEvent(func (_ mautrix.EventSource, evt *event.Event) { + err := olmMachine.FlushStore() + if err != nil { + panic(err) + } + }) + + err = client.Sync() + if err != nil { + log.Fatalf("error syncing: %v", err) + } +} + +func isBadEncryptError(err error) bool { + return err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession +} + +// this function assumes evt.Type is EventMessage +// return value is the body of the message to send back, if any +func getMsgResponse(client *mautrix.Client, evt *event.Event) string { + content := evt.Content.AsMessage() + if content.MsgType != event.MsgText { + return "" + } + mscs := getMSCs(content.Body) + ret := "" + for i, msc := range mscs { + log.Infof("MSC: %v %v\n", evt.ID, msc) + if i > 0 { + ret += "\n" + } + ret += fmt.Sprintf("https://github.com/matrix-org/matrix-doc/pull/%v", msc) + } + return ret +} + +func getMSCs(body string) (mscs []string) { + matches := mscRegex.FindAllStringSubmatch(body, -1) + for _, match := range matches { + mscs = append(mscs, match[1]) + } + return mscs +} + +func mkClient(store mautrix.Storer) *mautrix.Client { + homeserver := os.Getenv("HOMESERVER") + if homeserver == "" { + log.Fatal("required envvar HOMESERVER not set") + } + + userID := os.Getenv("USER_ID") + if userID == "" { + log.Fatal("required envvar USER_ID not set") + } + + deviceID := os.Getenv("DEVICE_ID") + if deviceID == "" { + log.Fatal("required envvar DEVICE_ID not set") + } + + accessToken := os.Getenv("ACCESS_TOKEN") + if accessToken == "" { + log.Fatal("required envvar ACCESS_TOKEN not set") + } + + client, err := mautrix.NewClient(homeserver, id.UserID(userID), accessToken) + if err != nil { + log.Fatalf("couldn't create client: %v", err) + } + client.DeviceID = id.DeviceID(deviceID) + client.Store = store + + return client +} diff --git a/store.go b/store.go new file mode 100644 index 0000000..e403eda --- /dev/null +++ b/store.go @@ -0,0 +1,129 @@ +package main + +import ( + "sync" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MSCBotStore struct { + sync.RWMutex + FilterIDs map[id.UserID]string + NextBatches map[id.UserID]string + Rooms map[id.RoomID]*mautrix.Room +} + +func NewMSCBotStore() *MSCBotStore { + return &MSCBotStore{ + FilterIDs: make(map[id.UserID]string), + NextBatches: make(map[id.UserID]string), + Rooms: make(map[id.RoomID]*mautrix.Room), + } +} + +// mautrix.Storer interface implemented below + +func (s *MSCBotStore) SaveFilterID(userID id.UserID, filterID string) { + s.Lock() + defer s.Unlock() + s.FilterIDs[userID] = filterID +} + +func (s *MSCBotStore) LoadFilterID(userID id.UserID) string { + s.RLock() + defer s.RUnlock() + return s.FilterIDs[userID] +} + +func (s *MSCBotStore) SaveNextBatch(userID id.UserID, nextBatchToken string) { + s.Lock() + defer s.Unlock() + s.NextBatches[userID] = nextBatchToken +} + +func (s *MSCBotStore) LoadNextBatch(userID id.UserID) string { + s.RLock() + defer s.RUnlock() + return s.NextBatches[userID] +} + +func (s *MSCBotStore) SaveRoom(room *mautrix.Room) { + s.Lock() + defer s.Unlock() + s.Rooms[room.ID] = room +} + +func (s *MSCBotStore) LoadRoom(roomID id.RoomID) *mautrix.Room { + s.RLock() + defer s.RUnlock() + return s.Rooms[roomID] +} + +func (s *MSCBotStore) UpdateState(_ mautrix.EventSource, evt *event.Event) { + if !evt.Type.IsState() { + return + } + room := s.LoadRoom(evt.RoomID) + if room == nil { + room = mautrix.NewRoom(evt.RoomID) + s.SaveRoom(room) + } + room.UpdateState(evt) +} + +// crypto.StateStore interface implemented below + +// IsEncrypted returns whether a room is encrypted. +func (s *MSCBotStore) IsEncrypted(roomID id.RoomID) bool { + s.RLock() + defer s.RUnlock() + if room, exists := s.Rooms[roomID]; exists { + return room.GetStateEvent(event.StateEncryption, "") != nil + } + return false +} + +// GetEncryptionEvent returns the encryption event's content for an encrypted room. +func (s *MSCBotStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent { + s.RLock() + defer s.RUnlock() + room, exists := s.Rooms[roomID] + if !exists { + return nil + } + evt := room.GetStateEvent(event.StateEncryption, "") + content, ok := evt.Content.Parsed.(*event.EncryptionEventContent) + if !ok { + return nil + } + return content +} + +// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID. +func (s *MSCBotStore) FindSharedRooms(userID id.UserID) []id.RoomID { + s.RLock() + defer s.RUnlock() + var sharedRooms []id.RoomID + for roomID, room := range s.Rooms { + // if room isn't encrypted, skip + if room.GetStateEvent(event.StateEncryption, "") == nil { + continue + } + if room.GetMembershipState(userID) == event.MembershipJoin { + sharedRooms = append(sharedRooms, roomID) + } + } + return sharedRooms +} + +func (s *MSCBotStore) GetRoomMembers(roomID id.RoomID) []id.UserID { + var members []id.UserID + for userID, evt := range s.Rooms[roomID].State[event.StateMember] { + if evt.Content.Parsed.(*event.MemberEventContent).Membership.IsInviteOrJoin() { + members = append(members, id.UserID(userID)) + } + } + return members +} -- 2.38.4