Skip to content

Commit

Permalink
Add FetchMemberships function
Browse files Browse the repository at this point in the history
Pulled out of #329.
  • Loading branch information
David Robertson committed Nov 2, 2023
1 parent 13d4e02 commit f0ea7cb
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
51 changes: 51 additions & 0 deletions state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,57 @@ func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
return nil
}

// FetchMemberships looks up the latest snapshot for the given room and determines the
// latest membership events in the room. Returns
// - the list of joined members,
// - the list of invited members, and then
// - the list of all other memberships. (This is called "leaves", but includes bans. It
// also includes knocks, but the proxy doesn't support those.)
//
// Each lists' members are arranged in no particular order.
//
// TODO: there is a very similar query in ResetMetadataState which also selects events
// events row for memberships. It is a shame to have to do this twice---can we query
// once and pass the data around?
func (s *Storage) FetchMemberships(roomID string) (joins, invites, leaves []string, err error) {
var events []Event
err = s.DB.Select(&events, `
WITH snapshot(membership_nids) AS (
SELECT membership_events
FROM syncv3_snapshots
JOIN syncv3_rooms ON snapshot_id = current_snapshot_id
WHERE syncv3_rooms.room_id = $1
)
SELECT state_key, membership
FROM syncv3_events JOIN snapshot ON (
event_nid = ANY( membership_nids )
)
`, roomID)
if err != nil {
return nil, nil, nil, err
}

joins = make([]string, 0, len(events))
invites = make([]string, 0, len(events))
leaves = make([]string, 0, len(events))

for _, e := range events {
switch e.Membership {
case "_join":
fallthrough
case "join":
joins = append(joins, e.StateKey)
case "_invite":
fallthrough
case "invite":
invites = append(invites, e.StateKey)
default:
leaves = append(leaves, e.StateKey)
}
}
return
}

// Returns all current NOT MEMBERSHIP state events matching the event types given in all rooms. Returns a map of
// room ID to events in that room.
func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventTypes []string) (map[string][]Event, error) {
Expand Down
43 changes: 43 additions & 0 deletions state/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,49 @@ func TestCircularSlice(t *testing.T) {

}

func TestStorage_FetchMemberships(t *testing.T) {
assertNoError(t, cleanDB(t))
store := NewStorage(postgresConnectionString)
defer store.Teardown()

events := []json.RawMessage{
testutils.NewStateEvent(t, "m.room.create", "", "@alice:test", map[string]any{}),
testutils.NewStateEvent(t, "m.room.member", "@alice:test", "@alice:test", map[string]any{"membership": "join"}),
testutils.NewStateEvent(t, "m.room.member", "@brian:test", "@alice:test", map[string]any{"membership": "invite"}),
testutils.NewStateEvent(t, "m.room.member", "@chris:test", "@chris:test", map[string]any{"membership": "leave"}),
testutils.NewStateEvent(t, "m.room.member", "@david:test", "@alice:test", map[string]any{"membership": "ban"}),
testutils.NewStateEvent(t, "m.room.member", "@erika:test", "@erika:test", map[string]any{"membership": "join"}),
testutils.NewStateEvent(t, "m.room.member", "@frank:test", "@erika:test", map[string]any{"membership": "invite"}),
testutils.NewStateEvent(t, "m.room.member", "@glory:test", "@glory:test", map[string]any{"membership": "leave"}),
testutils.NewStateEvent(t, "m.room.member", "@helen:test", "@alice:test", map[string]any{"membership": "ban"}),
}

const roomID = "!unimportant"
err := sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) (err error) {
_, err = store.Accumulator.Initialise(roomID, events)
return err
})
assertNoError(t, err)

joins, invites, leaves, err := store.FetchMemberships(roomID)
assertNoError(t, err)

// Do not assume an order from the DB.
sort.Slice(joins, func(i, j int) bool {
return joins[i] < joins[j]
})
sort.Slice(invites, func(i, j int) bool {
return invites[i] < invites[j]
})
sort.Slice(leaves, func(i, j int) bool {
return leaves[i] < leaves[j]
})

assertValue(t, "joins", joins, []string{"@alice:test", "@erika:test"})
assertValue(t, "invites", invites, []string{"@brian:test", "@frank:test"})
assertValue(t, "joins", leaves, []string{"@chris:test", "@david:test", "@glory:test", "@helen:test"})
}

func cleanDB(t *testing.T) error {
// make a fresh DB which is unpolluted from other tests
db, close := connectToDB(t)
Expand Down

0 comments on commit f0ea7cb

Please sign in to comment.