Skip to content

Commit

Permalink
bridgev2: add method for getting all portals with Matrix room
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jul 14, 2024
1 parent edf1a8d commit fb9fb5a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
5 changes: 5 additions & 0 deletions bridgev2/database/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const (
getPortalByKeyQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3`
getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')`
getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL`
getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2`

findPortalReceiverQuery = `SELECT id, receiver FROM portal WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='') LIMIT 1`
Expand Down Expand Up @@ -119,6 +120,10 @@ func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal,
return pq.QueryOne(ctx, getPortalByMXIDQuery, pq.BridgeID, mxid)
}

func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID)
}

func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.PortalID) ([]*Portal, error) {
return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID)
}
Expand Down
27 changes: 27 additions & 0 deletions bridgev2/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,23 @@ func (portal *Portal) updateLogger() {
portal.Log = logWith.Logger()
}

func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Portal) ([]*Portal, error) {
output := make([]*Portal, 0, len(portals))
for _, dbPortal := range portals {
if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok {
output = append(output, cached)
} else {
loaded, err := br.loadPortal(ctx, dbPortal, nil, nil)
if err != nil {
return nil, err
} else if loaded != nil {
output = append(output, loaded)
}
}
}
return output, nil
}

func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) {
cached, ok := br.portalsByKey[key]
if ok {
Expand Down Expand Up @@ -172,6 +189,16 @@ func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal,
return br.loadPortal(ctx, db, err, nil)
}

func (br *Bridge) GetAllPortalsWithMXID(ctx context.Context) ([]*Portal, error) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
rows, err := br.DB.Portal.GetAllWithMXID(ctx)
if err != nil {
return nil, err
}
return br.loadManyPortals(ctx, rows)
}

func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
Expand Down

0 comments on commit fb9fb5a

Please sign in to comment.