Skip to content

Commit

Permalink
server: Make manual peer disconnect synchronous.
Browse files Browse the repository at this point in the history
This refactors the logic for manually disconnecting peers out of the
peer handler since it no longer needs to be plumbed through the query
channel.

This is a part of the overall effort to convert all of the code related
to updating and querying the server's peer state to synchronous code
that makes use of a separate mutex to protect it.
  • Loading branch information
davecgh committed May 9, 2024
1 parent a10ec87 commit c978cb8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 54 deletions.
56 changes: 44 additions & 12 deletions rpcadaptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,19 +227,55 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error {
return nil
}

// disconnectNode disconnects any peers that the provided compare function
// returns true for. It applies to both inbound and outbound peers.
//
// An error will be returned if no matching peers are found (aka the compare
// function returns false for all peers).
//
// This function is safe for concurrent access.
func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error {
state := &cm.server.peerState
defer state.Unlock()
state.Lock()

// Check inbound peers. No callback is passed since there are no additional
// actions on disconnect for inbound peers.
found := disconnectPeer(state.inboundPeers, cmp, nil)
if found {
return nil
}

// Check outbound peers in a loop to ensure all outbound connections to the
// same ip:port are disconnected when there are multiple.
var numFound uint32
for ; ; numFound++ {
found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) {
// Update the group counts since the peer will be removed from the
// persistent peers just after this func returns.
remoteAddr := wireToAddrmgrNetAddress(sp.NA())
state.outboundGroups[remoteAddr.GroupKey()]--
})
if !found {
break
}
}

if numFound == 0 {
return errors.New("peer not found")
}
return nil
}

// DisconnectByID disconnects the peer associated with the provided id. This
// applies to both inbound and outbound peers. Attempting to remove an id that
// does not exist will return an error.
//
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) DisconnectByID(id int32) error {
replyChan := make(chan error)
cm.server.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
cmp := func(sp *serverPeer) bool { return sp.ID() == id }
return cm.disconnectNode(cmp)
}

// DisconnectByAddr disconnects the peer associated with the provided address.
Expand All @@ -249,12 +285,8 @@ func (cm *rpcConnManager) DisconnectByID(id int32) error {
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) DisconnectByAddr(addr string) error {
replyChan := make(chan error)
cm.server.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
cmp := func(sp *serverPeer) bool { return sp.Addr() == addr }
return cm.disconnectNode(cmp)
}

// ConnectedCount returns the number of currently connected peers.
Expand Down
42 changes: 0 additions & 42 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1928,51 +1928,9 @@ func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) {
})
}

type disconnectNodeMsg struct {
cmp func(*serverPeer) bool
reply chan error
}

// handleQuery is the central handler for all queries and commands from other
// goroutines related to peer state.
func (s *server) handleQuery(ctx context.Context, state *peerState, querymsg interface{}) {
switch msg := querymsg.(type) {
case disconnectNodeMsg:
// Check inbound peers. We pass a nil callback since we don't
// require any additional actions on disconnect for inbound peers.
state.Lock()
found := disconnectPeer(state.inboundPeers, msg.cmp, nil)
if found {
state.Unlock()
msg.reply <- nil
return
}

// Check outbound peers.
found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) {
// Keep group counts ok since we remove from
// the list now.
remoteAddr := wireToAddrmgrNetAddress(sp.NA())
state.outboundGroups[remoteAddr.GroupKey()]--
})
if found {
// If there are multiple outbound connections to the same
// ip:port, continue disconnecting them all until no such
// peers are found.
for found {
found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) {
remoteAddr := wireToAddrmgrNetAddress(sp.NA())
state.outboundGroups[remoteAddr.GroupKey()]--
})
}
state.Unlock()
msg.reply <- nil
return
}
state.Unlock()

msg.reply <- errors.New("peer not found")
}
}

// disconnectPeer attempts to drop the connection of a targeted peer in the
Expand Down

0 comments on commit c978cb8

Please sign in to comment.