diff --git a/rpcadaptors.go b/rpcadaptors.go index d246582868..00370ef01c 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -227,6 +227,46 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error { return nil } +// disconnectNode disconnects any peers that the provided compare function +// returns true for. It applies to both inbound and outbound peers. +// +// An error will be returned if no matching peers are found (aka the compare +// function returns false for all peers). +// +// This function is safe for concurrent access. +func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error { + state := &cm.server.peerState + defer state.Unlock() + state.Lock() + + // Check inbound peers. No callback is passed since there are no additional + // actions on disconnect for inbound peers. + found := disconnectPeer(state.inboundPeers, cmp, nil) + if found { + return nil + } + + // Check outbound peers in a loop to ensure all outbound connections to the + // same ip:port are disconnected when there are multiple. + var numFound uint32 + for ; ; numFound++ { + found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) { + // Update the group counts since the peer will be removed from the + // persistent peers just after this func returns. + remoteAddr := wireToAddrmgrNetAddress(sp.NA()) + state.outboundGroups[remoteAddr.GroupKey()]-- + }) + if !found { + break + } + } + + if numFound == 0 { + return errors.New("peer not found") + } + return nil +} + // DisconnectByID disconnects the peer associated with the provided id. This // applies to both inbound and outbound peers. Attempting to remove an id that // does not exist will return an error. @@ -234,12 +274,8 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) DisconnectByID(id int32) error { - replyChan := make(chan error) - cm.server.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.ID() == id }, - reply: replyChan, - } - return <-replyChan + cmp := func(sp *serverPeer) bool { return sp.ID() == id } + return cm.disconnectNode(cmp) } // DisconnectByAddr disconnects the peer associated with the provided address. @@ -249,12 +285,8 @@ func (cm *rpcConnManager) DisconnectByID(id int32) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) DisconnectByAddr(addr string) error { - replyChan := make(chan error) - cm.server.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, - reply: replyChan, - } - return <-replyChan + cmp := func(sp *serverPeer) bool { return sp.Addr() == addr } + return cm.disconnectNode(cmp) } // ConnectedCount returns the number of currently connected peers. diff --git a/server.go b/server.go index ef730d3a84..2363b9de59 100644 --- a/server.go +++ b/server.go @@ -1928,51 +1928,9 @@ func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) { }) } -type disconnectNodeMsg struct { - cmp func(*serverPeer) bool - reply chan error -} - // handleQuery is the central handler for all queries and commands from other // goroutines related to peer state. func (s *server) handleQuery(ctx context.Context, state *peerState, querymsg interface{}) { - switch msg := querymsg.(type) { - case disconnectNodeMsg: - // Check inbound peers. We pass a nil callback since we don't - // require any additional actions on disconnect for inbound peers. - state.Lock() - found := disconnectPeer(state.inboundPeers, msg.cmp, nil) - if found { - state.Unlock() - msg.reply <- nil - return - } - - // Check outbound peers. - found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) { - // Keep group counts ok since we remove from - // the list now. - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - }) - if found { - // If there are multiple outbound connections to the same - // ip:port, continue disconnecting them all until no such - // peers are found. - for found { - found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) { - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - }) - } - state.Unlock() - msg.reply <- nil - return - } - state.Unlock() - - msg.reply <- errors.New("peer not found") - } } // disconnectPeer attempts to drop the connection of a targeted peer in the