Skip to content

Commit

Permalink
[pseudoID] More pseudo ID fixes (#3167)
Browse files Browse the repository at this point in the history
Signed-off-by: `Sam Wedgwood <sam@wedgwood.dev>`
  • Loading branch information
swedgwood authored Aug 15, 2023
1 parent fa6c7ba commit 9a12420
Show file tree
Hide file tree
Showing 24 changed files with 472 additions and 237 deletions.
31 changes: 22 additions & 9 deletions clientapi/routing/joined_rooms.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,36 @@ func GetJoinedRooms(
device *userapi.Device,
rsAPI api.ClientRoomserverAPI,
) util.JSONResponse {
var res api.QueryRoomsForUserResponse
err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("Invalid device user ID")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}

rooms, err := rsAPI.QueryRoomsForUser(req.Context(), *deviceUserID, "join")
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
JSON: spec.Unknown("internal server error"),
}
}
if res.RoomIDs == nil {
res.RoomIDs = []string{}

var roomIDStrs []string
if rooms == nil {
roomIDStrs = []string{}
} else {
roomIDStrs = make([]string, len(rooms))
for i, roomID := range rooms {
roomIDStrs[i] = roomID.String()
}
}

return util.JSONResponse{
Code: http.StatusOK,
JSON: getJoinedRoomsResponse{res.RoomIDs},
JSON: getJoinedRoomsResponse{roomIDStrs},
}
}
21 changes: 15 additions & 6 deletions clientapi/routing/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,15 @@ func updateProfile(
profile *authtypes.Profile,
userID string, evTime time.Time,
) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse
err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}, err
}

rooms, err := rsAPI.QueryRoomsForUser(ctx, *deviceUserID, "join")
if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{
Expand All @@ -264,6 +268,11 @@ func updateProfile(
}, err
}

roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}

_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
Expand All @@ -274,7 +283,7 @@ func updateProfile(
}

events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, evTime, rsAPI,
ctx, roomIDStrs, *profile, userID, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
Expand Down
13 changes: 10 additions & 3 deletions clientapi/routing/sendevent.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,17 @@ func generateSendEvent(
}
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil || senderID == nil {
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("Unable to find senderID for user"),
Code: http.StatusInternalServerError,
JSON: spec.NotFound("internal server error"),
}
} else if senderID == nil {
// TODO: is it always the case that lack of a sender ID means they're not joined?
// And should this logic be deferred to the roomserver somehow?
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("not joined to room"),
}
}

Expand Down
42 changes: 25 additions & 17 deletions clientapi/routing/server_notices.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,34 +94,42 @@ func SendServerNotice(
}
}

userID, err := spec.NewUserID(r.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid user ID"),
}
}

// get rooms for specified user
allUserRooms := []string{}
userRooms := api.QueryRoomsForUserResponse{}
allUserRooms := []spec.RoomID{}
// Get rooms the user is either joined, invited or has left.
for _, membership := range []string{"join", "invite", "leave"} {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: r.UserID,
WantMembership: membership,
}, &userRooms); err != nil {
userRooms, queryErr := rsAPI.QueryRoomsForUser(ctx, *userID, membership)
if queryErr != nil {
return util.ErrorResponse(err)
}
allUserRooms = append(allUserRooms, userRooms.RoomIDs...)
allUserRooms = append(allUserRooms, userRooms...)
}

// get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName)
senderRooms := api.QueryRoomsForUserResponse{}
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: senderUserID,
WantMembership: "join",
}, &senderRooms); err != nil {
senderUserID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName), true)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
senderRooms, err := rsAPI.QueryRoomsForUser(ctx, *senderUserID, "join")
if err != nil {
return util.ErrorResponse(err)
}

// check if we have rooms in common
commonRooms := []string{}
commonRooms := []spec.RoomID{}
for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs {
for _, senderRoomID := range senderRooms {
if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID)
}
Expand All @@ -139,7 +147,7 @@ func SendServerNotice(

// create a new room for the user
if len(commonRooms) == 0 {
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID)
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID.String())
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent)
if err != nil {
Expand Down Expand Up @@ -195,7 +203,7 @@ func SendServerNotice(
}
}

roomID = commonRooms[0]
roomID = commonRooms[0].String()
membershipRes := api.QueryMembershipForUserResponse{}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
if err != nil {
Expand Down
41 changes: 29 additions & 12 deletions federationapi/consumers/keychange.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,27 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
return true
}

var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{
UserID: m.UserID,
WantMembership: "join",
}, &queryRes)
userID, err := spec.NewUserID(m.UserID, true)
if err != nil {
sentry.CaptureException(err)
logger.WithError(err).Error("invalid user ID")
return true
}

roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *userID, "join")
if err != nil {
sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined rooms for user")
return true
}

roomIDStrs := make([]string, len(roomIDs))
for i, room := range roomIDs {
roomIDStrs[i] = room.String()
}

// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil {
sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
Expand Down Expand Up @@ -179,18 +187,27 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
}
logger := logrus.WithField("user_id", output.UserID)

var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{
UserID: output.UserID,
WantMembership: "join",
}, &queryRes)
outputUserID, err := spec.NewUserID(output.UserID, true)
if err != nil {
sentry.CaptureException(err)
logrus.WithError(err).Errorf("invalid user ID")
return true
}

rooms, err := t.rsAPI.QueryRoomsForUser(t.ctx, *outputUserID, "join")
if err != nil {
sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user")
return true
}

roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}

// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil {
sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
Expand Down
20 changes: 14 additions & 6 deletions federationapi/consumers/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -94,16 +95,23 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
return true
}

var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{
UserID: userID,
WantMembership: "join",
}, &queryRes)
parsedUserID, err := spec.NewUserID(userID, true)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("user_id", userID).Error("invalid user ID")
return true
}

roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *parsedUserID, "join")
if err != nil {
log.WithError(err).Error("failed to calculate joined rooms for user")
return true
}

roomIDStrs := make([]string, len(roomIDs))
for i, roomID := range roomIDs {
roomIDStrs[i] = roomID.String()
}

presence := msg.Header.Get("presence")

ts, err := strconv.Atoi(msg.Header.Get("last_active_ts"))
Expand All @@ -112,7 +120,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
}

// send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil {
log.WithError(err).Error("failed to get joined hosts")
return true
Expand Down
22 changes: 13 additions & 9 deletions federationapi/federationapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
queryRoomsForUser func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
}

func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
Expand All @@ -54,11 +54,11 @@ func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.Input
}

// keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if f.queryRoomsForUser == nil {
return nil
return nil, nil
}
return f.queryRoomsForUser(ctx, req, res)
return f.queryRoomsForUser(ctx, userID, desiredMembership)
}

// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
Expand Down Expand Up @@ -199,18 +199,22 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator)

roomID, err := spec.NewRoomID(room.ID)
if err != nil {
t.Fatalf("Invalid room ID: %q", roomID)
}

rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
}
},
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if req.UserID == joiningUser.ID && req.WantMembership == "join" {
res.RoomIDs = []string{room.ID}
return nil
queryRoomsForUser: func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if userID.String() == joiningUser.ID && desiredMembership == "join" {
return []spec.RoomID{*roomID}, nil
}
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
return nil, fmt.Errorf("unexpected queryRoomsForUser: %v, %v", userID, desiredMembership)
},
}
fc := &fedClient{
Expand Down
Loading

0 comments on commit 9a12420

Please sign in to comment.