Skip to content

Commit

Permalink
Allow disabling automatic key fetching for Olm machine
Browse files Browse the repository at this point in the history
Many crypto operations in the Olm machine have a possible side effect of
fetching keys from the server if they are missing. This may be undesired
in some special cases.

To tracking which users need key fetching, CryptoStore now exposes APIs
to mark and query the status.
  • Loading branch information
hifi committed Jan 10, 2024
1 parent 8da3a17 commit 20de54d
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 26 deletions.
24 changes: 14 additions & 10 deletions crypto/devicelist.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@ var (
InvalidKeySignature = errors.New("invalid signature on device keys")
)

func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) map[id.DeviceID]*id.Device {
return mach.fetchKeys(ctx, []id.UserID{user}, "", true)[user]
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
log := zerolog.Ctx(ctx)

if keys, err := mach.FetchKeys(ctx, []id.UserID{user}, true); err != nil {
log.Err(err).Msg("Failed to load devices")
} else if keys != nil {
return keys[user]
}

return nil
}

func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
Expand Down Expand Up @@ -85,19 +93,16 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
}
}

func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
// TODO this function should probably return errors
func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) {
req := &mautrix.ReqQueryKeys{
DeviceKeys: mautrix.DeviceKeysRequest{},
Timeout: 10 * 1000,
Token: sinceToken,
}
log := mach.machOrContextLog(ctx)
if !includeUntracked {
var err error
users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users)
if err != nil {
log.Warn().Err(err).Msg("Failed to filter tracked user list")
return nil, fmt.Errorf("failed to filter tracked user list: %w", err)
}
}
if len(users) == 0 {
Expand All @@ -109,8 +114,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users")
resp, err := mach.Client.QueryKeys(ctx, req)
if err != nil {
log.Error().Err(err).Msg("Failed to query keys")
return
return nil, fmt.Errorf("failed to query keys: %w", err)
}
for server, err := range resp.Failures {
log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server")
Expand Down Expand Up @@ -189,7 +193,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys)

return data
return data, nil
}

// OnDevicesChanged finds all shared rooms with the given user and invalidates outbound sessions in those rooms.
Expand Down
22 changes: 15 additions & 7 deletions crypto/encryptmegolm.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,21 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
}

if len(fetchKeys) > 0 {
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
log.Debug().
Int("device_count", len(devices)).
Str("target_user_id", userID.String()).
Msg("Got device keys for user")
missingSessions[userID] = devices
if mach.DisableKeyFetching {
log.Warn().Strs("users", strishArray(fetchKeys)).Msg("Keys missing but key fetching is disabled")
} else {
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil {
log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys")
} else if keys != nil {
for userID, devices := range keys {
log.Debug().
Int("device_count", len(devices)).
Str("target_user_id", userID.String()).
Msg("Got device keys for user")
missingSessions[userID] = devices
}
}
}
}

Expand Down
18 changes: 13 additions & 5 deletions crypto/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ type OlmMachine struct {

PlaintextMentions bool

// Never ask the server for keys automatically as a side effect.
DisableKeyFetching bool

SendKeysMinTrust id.TrustState
ShareKeysMinTrust id.TrustState

Expand Down Expand Up @@ -224,7 +227,11 @@ func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string)
Str("trace_id", traceID).
Interface("changes", dl.Changed).
Msg("Device list changes in /sync")
mach.fetchKeys(context.TODO(), dl.Changed, since, false)
if mach.DisableKeyFetching {
mach.CryptoStore.MarkTrackedUsersOutdated(context.TODO(), dl.Changed)
} else {
mach.FetchKeys(context.TODO(), dl.Changed, false)
}
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
}
}
Expand Down Expand Up @@ -413,11 +420,12 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
if err != nil {
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
} else if device != nil {
} else if device != nil || mach.DisableKeyFetching {
return device, nil
}
usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true)
if devices, ok := usersToDevices[userID]; ok {
if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil {
return nil, fmt.Errorf("failed to fetch keys: %w", err)
} else if devices, ok := usersToDevices[userID]; ok {
if device, ok = devices[deviceID]; ok {
return device, nil
}
Expand All @@ -431,7 +439,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
// the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
if err != nil || deviceIdentity != nil {
if err != nil || deviceIdentity != nil || mach.DisableKeyFetching {
return deviceIdentity, err
}
mach.machOrContextLog(ctx).Debug().
Expand Down
27 changes: 25 additions & 2 deletions crypto/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,9 @@ func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, de
// PutDevices stores the device identity information for the given user ID.
func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
_, err := store.DB.Exec(ctx, "INSERT OR REPLACE INTO crypto_tracked_user (user_id, devices_outdated) VALUES ($1, FALSE)", userID)
if err != nil {
return fmt.Errorf("failed to add user to tracked users list: %w", err)
return fmt.Errorf("failed to upsert user to tracked users list: %w", err)
}

_, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
Expand Down Expand Up @@ -734,6 +734,29 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
}

// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error {
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
for _, userID := range users {
_, err := store.DB.Exec(ctx, "INSERT OR REPLACE INTO crypto_tracked_user (user_id, devices_outdated) VALUES ($1, TRUE)", userID)
if err != nil {
return fmt.Errorf("failed to upsert user to tracked users list: %w", err)
}
}

return nil
})
}

// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) {
rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE")
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
}

// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
_, err := store.DB.Exec(ctx, `
Expand Down
5 changes: 3 additions & 2 deletions crypto/sql_store_upgrade/00-latest-revision.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-- v0 -> v10: Latest revision
-- v0 -> v11: Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
Expand All @@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS crypto_message_index (
);

CREATE TABLE IF NOT EXISTS crypto_tracked_user (
user_id TEXT PRIMARY KEY
user_id TEXT PRIMARY KEY,
devices_outdated BOOLEAN NOT NULL DEFAULT FALSE
);

CREATE TABLE IF NOT EXISTS crypto_device (
Expand Down
2 changes: 2 additions & 0 deletions crypto/sql_store_upgrade/11-outdated-devices.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- v11: Add devices_outdated field to crypto_tracked_user
ALTER TABLE crypto_tracked_user ADD COLUMN devices_outdated BOOLEAN NOT NULL DEFAULT FALSE;
25 changes: 25 additions & 0 deletions crypto/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ type Store interface {
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error)
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
MarkTrackedUsersOutdated(context.Context, []id.UserID) error
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
GetOutdatedTrackedUsers(context.Context) ([]id.UserID, error)

// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error
Expand Down Expand Up @@ -148,6 +152,7 @@ type MemoryStore struct {
Devices map[id.UserID]map[id.DeviceID]*id.Device
CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
OutdatedUsers map[id.UserID]struct{}
}

var _ Store = (*MemoryStore)(nil)
Expand All @@ -167,6 +172,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
Devices: make(map[id.UserID]map[id.DeviceID]*id.Device),
CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey),
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
OutdatedUsers: make(map[id.UserID]struct{}),
}
}

Expand Down Expand Up @@ -517,6 +523,25 @@ func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID)
return users[:ptr], nil
}

func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error {
gs.lock.Lock()
for _, userID := range users {
gs.OutdatedUsers[userID] = struct{}{}
}
gs.lock.Unlock()
return nil
}

func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) {
gs.lock.RLock()
users := make([]id.UserID, 0, len(gs.OutdatedUsers))
for userID := range gs.OutdatedUsers {
users = append(users, userID)
}
gs.lock.RUnlock()
return users, nil
}

func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
gs.lock.RLock()
userKeys, ok := gs.CrossSigningKeys[userID]
Expand Down
28 changes: 28 additions & 0 deletions crypto/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package crypto
import (
"context"
"database/sql"
"golang.org/x/exp/slices"
"strconv"
"testing"

Expand Down Expand Up @@ -259,3 +260,30 @@ func TestStoreDevices(t *testing.T) {
})
}
}

func TestOutdatedTrackedUsers(t *testing.T) {
stores := getCryptoStores(t)
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
users := []id.UserID{"user0", "user1", "user2"}
err := store.MarkTrackedUsersOutdated(context.TODO(), users[1:1])
if err != nil {
t.Errorf("Error marking tracked users outdated: %v", err)
}
err = store.MarkTrackedUsersOutdated(context.TODO(), users)
if err != nil {
t.Errorf("Error marking tracked users outdated: %v", err)
}
outdated, err := store.GetOutdatedTrackedUsers(context.TODO())
if err != nil {
t.Errorf("Error filtering tracked users: %v", err)
}

slices.Sort(outdated)

if !slices.Equal(outdated, users) {
t.Errorf("Expected to outdated list to be %v, got %v", users, outdated)
}
})
}
}

0 comments on commit 20de54d

Please sign in to comment.