Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: Implement separate mutex for peer state. #3251

Merged
merged 13 commits into from
May 16, 2024
Merged
186 changes: 136 additions & 50 deletions rpcadaptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/decred/dcrd/chaincfg/chainhash"
"github.com/decred/dcrd/chaincfg/v3"
"github.com/decred/dcrd/connmgr/v3"
"github.com/decred/dcrd/dcrutil/v4"
"github.com/decred/dcrd/internal/blockchain"
"github.com/decred/dcrd/internal/mempool"
Expand Down Expand Up @@ -124,13 +125,77 @@ var _ rpcserver.ConnManager = (*rpcConnManager)(nil)
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) Connect(addr string, permanent bool) error {
replyChan := make(chan error)
cm.server.query <- connectNodeMsg{
addr: addr,
permanent: permanent,
reply: replyChan,
// Prevent duplicate connections to the same peer.
connManager := cm.server.connManager
err := connManager.ForEachConnReq(func(c *connmgr.ConnReq) error {
if c.Addr != nil && c.Addr.String() == addr {
if c.Permanent {
return errors.New("peer exists as a permanent peer")
}

switch c.State() {
case connmgr.ConnPending:
return errors.New("peer pending connection")
case connmgr.ConnEstablished:
return errors.New("peer already connected")

}
}
return nil
})
if err != nil {
return err
}

netAddr, err := addrStringToNetAddr(addr)
if err != nil {
return err
}

// Limit max number of total peers.
cm.server.peerState.Lock()
count := cm.server.peerState.count()
cm.server.peerState.Unlock()
if count >= cfg.MaxPeers {
return errors.New("max peers reached")
}

go connManager.Connect(context.Background(), &connmgr.ConnReq{
Addr: netAddr,
Permanent: permanent,
})
return nil
}

// removeNode removes any peers that the provided compare function return true
// for from the list of persistent peers.
//
// An error will be returned if no matching peers are found (aka the compare
// function returns false for all peers).
func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error {
state := &cm.server.peerState
state.Lock()
found := disconnectPeer(state.persistentPeers, cmp, func(sp *serverPeer) {
// Update the group counts since the peer will be removed from the
// persistent peers just after this func returns.
remoteAddr := wireToAddrmgrNetAddress(sp.NA())
state.outboundGroups[remoteAddr.GroupKey()]--

connReq := sp.connReq.Load()
peerLog.Debugf("Removing persistent peer %s (reqid %d)", remoteAddr,
connReq.ID())

// Mark the peer's connReq as nil to prevent it from scheduling a
// re-connect attempt.
sp.connReq.Store(nil)
cm.server.connManager.Remove(connReq.ID())
})
state.Unlock()

if !found {
return errors.New("peer not found")
}
return <-replyChan
return nil
}

// RemoveByID removes the peer associated with the provided id from the list of
Expand All @@ -140,12 +205,8 @@ func (cm *rpcConnManager) Connect(addr string, permanent bool) error {
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) RemoveByID(id int32) error {
replyChan := make(chan error)
cm.server.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
cmp := func(sp *serverPeer) bool { return sp.ID() == id }
return cm.removeNode(cmp)
}

// RemoveByAddr removes the peer associated with the provided address from the
Expand All @@ -155,21 +216,54 @@ func (cm *rpcConnManager) RemoveByID(id int32) error {
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) RemoveByAddr(addr string) error {
replyChan := make(chan error)
cm.server.query <- removeNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
cmp := func(sp *serverPeer) bool { return sp.Addr() == addr }
err := cm.removeNode(cmp)
if err != nil {
netAddr, err := addrStringToNetAddr(addr)
if err != nil {
return err
}
return cm.server.connManager.CancelPending(netAddr)
davecgh marked this conversation as resolved.
Show resolved Hide resolved
}
return nil
}

// Cancel the connection if it could still be pending.
err := <-replyChan
if err != nil {
cm.server.query <- cancelPendingMsg{
addr: addr,
reply: replyChan,
// disconnectNode disconnects any peers that the provided compare function
// returns true for. It applies to both inbound and outbound peers.
//
// An error will be returned if no matching peers are found (aka the compare
// function returns false for all peers).
//
// This function is safe for concurrent access.
func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error {
state := &cm.server.peerState
defer state.Unlock()
state.Lock()

// Check inbound peers. No callback is passed since there are no additional
// actions on disconnect for inbound peers.
found := disconnectPeer(state.inboundPeers, cmp, nil)
if found {
return nil
davecgh marked this conversation as resolved.
Show resolved Hide resolved
}

// Check outbound peers in a loop to ensure all outbound connections to the
// same ip:port are disconnected when there are multiple.
var numFound uint32
for ; ; numFound++ {
found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) {
// Update the group counts since the peer will be removed from the
// persistent peers just after this func returns.
remoteAddr := wireToAddrmgrNetAddress(sp.NA())
state.outboundGroups[remoteAddr.GroupKey()]--
})
if !found {
break
}
}

return <-replyChan
if numFound == 0 {
return errors.New("peer not found")
}
return nil
}
Expand All @@ -181,12 +275,8 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error {
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) DisconnectByID(id int32) error {
replyChan := make(chan error)
cm.server.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.ID() == id },
reply: replyChan,
}
return <-replyChan
cmp := func(sp *serverPeer) bool { return sp.ID() == id }
return cm.disconnectNode(cmp)
}

// DisconnectByAddr disconnects the peer associated with the provided address.
Expand All @@ -196,12 +286,8 @@ func (cm *rpcConnManager) DisconnectByID(id int32) error {
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) DisconnectByAddr(addr string) error {
replyChan := make(chan error)
cm.server.query <- disconnectNodeMsg{
cmp: func(sp *serverPeer) bool { return sp.Addr() == addr },
reply: replyChan,
}
return <-replyChan
cmp := func(sp *serverPeer) bool { return sp.Addr() == addr }
return cm.disconnectNode(cmp)
}

// ConnectedCount returns the number of currently connected peers.
Expand All @@ -226,15 +312,16 @@ func (cm *rpcConnManager) NetTotals() (uint64, uint64) {
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) ConnectedPeers() []rpcserver.Peer {
replyChan := make(chan []*serverPeer)
cm.server.query <- getPeersMsg{reply: replyChan}
serverPeers := <-replyChan

// Convert to RPC server peers.
peers := make([]rpcserver.Peer, 0, len(serverPeers))
for _, sp := range serverPeers {
state := &cm.server.peerState
state.Lock()
peers := make([]rpcserver.Peer, 0, state.count())
state.forAllPeers(func(sp *serverPeer) {
if !sp.Connected() {
return
}
peers = append(peers, (*rpcPeer)(sp))
}
})
state.Unlock()
return peers
}

Expand All @@ -244,15 +331,14 @@ func (cm *rpcConnManager) ConnectedPeers() []rpcserver.Peer {
// This function is safe for concurrent access and is part of the
// rpcserver.ConnManager interface implementation.
func (cm *rpcConnManager) PersistentPeers() []rpcserver.Peer {
replyChan := make(chan []*serverPeer)
cm.server.query <- getAddedNodesMsg{reply: replyChan}
serverPeers := <-replyChan

// Convert to RPC server peers.
peers := make([]rpcserver.Peer, 0, len(serverPeers))
for _, sp := range serverPeers {
// Return a slice of the relevant peers converted to RPC server peers.
state := &cm.server.peerState
state.Lock()
peers := make([]rpcserver.Peer, 0, len(state.persistentPeers))
for _, sp := range state.persistentPeers {
peers = append(peers, (*rpcPeer)(sp))
}
state.Unlock()
return peers
}

Expand Down
Loading
Loading