diff --git a/cmd/server/matrix.go b/cmd/server/matrix.go index ba76c21..82671d8 100644 --- a/cmd/server/matrix.go +++ b/cmd/server/matrix.go @@ -1,4 +1,4 @@ -package main +package main import ( "context" @@ -9,11 +9,13 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/cryptohelper" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) type MatrixClient struct { @@ -24,6 +26,8 @@ type MatrixClient struct { TimeLocation *time.Location httpClient *http.Client crypto *cryptohelper.CryptoHelper + cryptoClient *mautrix.Client + syncOnce sync.Once } func newMatrixClientFromEnv() *MatrixClient { @@ -59,19 +63,21 @@ func newMatrixClientFromEnv() *MatrixClient { } if readEnv("MATRIX_ENABLE_CRYPTO", "1") != "0" { - m.crypto = initMatrixCrypto(hs, token) + m.crypto, m.cryptoClient = initMatrixCrypto(hs, token) + m.startCryptoSync() } return m } -func initMatrixCrypto(hs, token string) *cryptohelper.CryptoHelper { +func initMatrixCrypto(hs, token string) (*cryptohelper.CryptoHelper, *mautrix.Client) { cli, err := mautrix.NewClient(hs, "", token) if err != nil { - return nil + return nil, nil } if whoami, err := cli.Whoami(context.Background()); err == nil && whoami != nil { cli.SetCredentials(whoami.UserID, token) + cli.DeviceID = whoami.DeviceID } storePath := readEnv("MATRIX_CRYPTO_STORE_PATH", "/tmp/farmcal-matrix-crypto.db") @@ -79,12 +85,23 @@ func initMatrixCrypto(hs, token string) *cryptohelper.CryptoHelper { helper, err := cryptohelper.NewCryptoHelper(cli, pickleKey, storePath) if err != nil { - return nil + return nil, nil } if err = helper.Init(context.Background()); err != nil { - return nil + return nil, nil } - return helper + return helper, cli +} + +func (m *MatrixClient) startCryptoSync() { + if m.cryptoClient == nil || m.crypto == nil { + return + } + m.syncOnce.Do(func() { + go func() { + _ = m.cryptoClient.SyncWithContext(context.Background()) + }() + }) } func (m *MatrixClient) FetchRecentMessages(ctx context.Context) ([]MatrixMessage, error) { @@ -126,6 +143,7 @@ func (m *MatrixClient) FetchRecentMessages(ctx context.Context) ([]MatrixMessage if err := json.Unmarshal(messagesResp.Chunk[i], ev); err != nil { continue } + ev.RoomID = id.RoomID(m.RoomID) msg, ok := m.mapEventToMessage(ctx, ev) if ok { out = append(out, msg) @@ -147,6 +165,15 @@ func (m *MatrixClient) mapEventToMessage(ctx context.Context, ev *event.Event) ( case event.EventEncrypted: if m.crypto != nil { decrypted, err := m.crypto.Decrypt(ctx, ev) + if errors.Is(err, cryptohelper.NoSessionFound) { + content := ev.Content.AsEncrypted() + m.crypto.RequestSession(ctx, ev.RoomID, content.SenderKey, content.SessionID, ev.Sender, content.DeviceID) + waitCtx, cancel := context.WithTimeout(ctx, 8*time.Second) + if m.crypto.WaitForSession(waitCtx, ev.RoomID, content.SenderKey, content.SessionID, 7*time.Second) { + decrypted, err = m.crypto.Decrypt(ctx, ev) + } + cancel() + } if err == nil && decrypted != nil { if body, ok := extractBody(decrypted); ok { return MatrixMessage{Sender: sender, Body: body, Timestamp: ts}, true