diff --git a/rpcadaptors.go b/rpcadaptors.go index d629bfec2..e0fa96cd2 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -13,6 +13,7 @@ import ( "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/chaincfg/v3" + "github.com/decred/dcrd/connmgr/v3" "github.com/decred/dcrd/dcrutil/v4" "github.com/decred/dcrd/internal/blockchain" "github.com/decred/dcrd/internal/mempool" @@ -124,13 +125,77 @@ var _ rpcserver.ConnManager = (*rpcConnManager)(nil) // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) Connect(addr string, permanent bool) error { - replyChan := make(chan error) - cm.server.query <- connectNodeMsg{ - addr: addr, - permanent: permanent, - reply: replyChan, + // Prevent duplicate connections to the same peer. + connManager := cm.server.connManager + err := connManager.ForEachConnReq(func(c *connmgr.ConnReq) error { + if c.Addr != nil && c.Addr.String() == addr { + if c.Permanent { + return errors.New("peer exists as a permanent peer") + } + + switch c.State() { + case connmgr.ConnPending: + return errors.New("peer pending connection") + case connmgr.ConnEstablished: + return errors.New("peer already connected") + + } + } + return nil + }) + if err != nil { + return err + } + + netAddr, err := addrStringToNetAddr(addr) + if err != nil { + return err + } + + // Limit max number of total peers. + cm.server.peerState.Lock() + count := cm.server.peerState.count() + cm.server.peerState.Unlock() + if count >= cfg.MaxPeers { + return errors.New("max peers reached") + } + + go connManager.Connect(context.Background(), &connmgr.ConnReq{ + Addr: netAddr, + Permanent: permanent, + }) + return nil +} + +// removeNode removes any peers that the provided compare function return true +// for from the list of persistent peers. +// +// An error will be returned if no matching peers are found (aka the compare +// function returns false for all peers). +func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error { + state := &cm.server.peerState + state.Lock() + found := disconnectPeer(state.persistentPeers, cmp, func(sp *serverPeer) { + // Update the group counts since the peer will be removed from the + // persistent peers just after this func returns. + remoteAddr := wireToAddrmgrNetAddress(sp.NA()) + state.outboundGroups[remoteAddr.GroupKey()]-- + + connReq := sp.connReq.Load() + peerLog.Debugf("Removing persistent peer %s (reqid %d)", remoteAddr, + connReq.ID()) + + // Mark the peer's connReq as nil to prevent it from scheduling a + // re-connect attempt. + sp.connReq.Store(nil) + cm.server.connManager.Remove(connReq.ID()) + }) + state.Unlock() + + if !found { + return errors.New("peer not found") } - return <-replyChan + return nil } // RemoveByID removes the peer associated with the provided id from the list of @@ -140,12 +205,8 @@ func (cm *rpcConnManager) Connect(addr string, permanent bool) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) RemoveByID(id int32) error { - replyChan := make(chan error) - cm.server.query <- removeNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.ID() == id }, - reply: replyChan, - } - return <-replyChan + cmp := func(sp *serverPeer) bool { return sp.ID() == id } + return cm.removeNode(cmp) } // RemoveByAddr removes the peer associated with the provided address from the @@ -155,21 +216,54 @@ func (cm *rpcConnManager) RemoveByID(id int32) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) RemoveByAddr(addr string) error { - replyChan := make(chan error) - cm.server.query <- removeNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, - reply: replyChan, + cmp := func(sp *serverPeer) bool { return sp.Addr() == addr } + err := cm.removeNode(cmp) + if err != nil { + netAddr, err := addrStringToNetAddr(addr) + if err != nil { + return err + } + return cm.server.connManager.CancelPending(netAddr) } + return nil +} - // Cancel the connection if it could still be pending. - err := <-replyChan - if err != nil { - cm.server.query <- cancelPendingMsg{ - addr: addr, - reply: replyChan, +// disconnectNode disconnects any peers that the provided compare function +// returns true for. It applies to both inbound and outbound peers. +// +// An error will be returned if no matching peers are found (aka the compare +// function returns false for all peers). +// +// This function is safe for concurrent access. +func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error { + state := &cm.server.peerState + defer state.Unlock() + state.Lock() + + // Check inbound peers. No callback is passed since there are no additional + // actions on disconnect for inbound peers. + found := disconnectPeer(state.inboundPeers, cmp, nil) + if found { + return nil + } + + // Check outbound peers in a loop to ensure all outbound connections to the + // same ip:port are disconnected when there are multiple. + var numFound uint32 + for ; ; numFound++ { + found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) { + // Update the group counts since the peer will be removed from the + // persistent peers just after this func returns. + remoteAddr := wireToAddrmgrNetAddress(sp.NA()) + state.outboundGroups[remoteAddr.GroupKey()]-- + }) + if !found { + break } + } - return <-replyChan + if numFound == 0 { + return errors.New("peer not found") } return nil } @@ -181,12 +275,8 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) DisconnectByID(id int32) error { - replyChan := make(chan error) - cm.server.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.ID() == id }, - reply: replyChan, - } - return <-replyChan + cmp := func(sp *serverPeer) bool { return sp.ID() == id } + return cm.disconnectNode(cmp) } // DisconnectByAddr disconnects the peer associated with the provided address. @@ -196,12 +286,8 @@ func (cm *rpcConnManager) DisconnectByID(id int32) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) DisconnectByAddr(addr string) error { - replyChan := make(chan error) - cm.server.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, - reply: replyChan, - } - return <-replyChan + cmp := func(sp *serverPeer) bool { return sp.Addr() == addr } + return cm.disconnectNode(cmp) } // ConnectedCount returns the number of currently connected peers. @@ -226,15 +312,16 @@ func (cm *rpcConnManager) NetTotals() (uint64, uint64) { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) ConnectedPeers() []rpcserver.Peer { - replyChan := make(chan []*serverPeer) - cm.server.query <- getPeersMsg{reply: replyChan} - serverPeers := <-replyChan - - // Convert to RPC server peers. - peers := make([]rpcserver.Peer, 0, len(serverPeers)) - for _, sp := range serverPeers { + state := &cm.server.peerState + state.Lock() + peers := make([]rpcserver.Peer, 0, state.count()) + state.forAllPeers(func(sp *serverPeer) { + if !sp.Connected() { + return + } peers = append(peers, (*rpcPeer)(sp)) - } + }) + state.Unlock() return peers } @@ -244,15 +331,14 @@ func (cm *rpcConnManager) ConnectedPeers() []rpcserver.Peer { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) PersistentPeers() []rpcserver.Peer { - replyChan := make(chan []*serverPeer) - cm.server.query <- getAddedNodesMsg{reply: replyChan} - serverPeers := <-replyChan - - // Convert to RPC server peers. - peers := make([]rpcserver.Peer, 0, len(serverPeers)) - for _, sp := range serverPeers { + // Return a slice of the relevant peers converted to RPC server peers. + state := &cm.server.peerState + state.Lock() + peers := make([]rpcserver.Peer, 0, len(state.persistentPeers)) + for _, sp := range state.persistentPeers { peers = append(peers, (*rpcPeer)(sp)) } + state.Unlock() return peers } diff --git a/server.go b/server.go index e789ec707..05e81c002 100644 --- a/server.go +++ b/server.go @@ -329,46 +329,52 @@ func (sc *naSubmissionCache) bestSubmission(net addrmgr.NetAddressType) *naSubmi return best } -// peerState maintains state of inbound, persistent, outbound peers as well +// peerState houses state of inbound, persistent, and outbound peers as well // as banned peers and outbound groups. type peerState struct { + sync.Mutex + + // The following fields are protected by the embedded mutex. inboundPeers map[int32]*serverPeer outboundPeers map[int32]*serverPeer persistentPeers map[int32]*serverPeer banned map[string]time.Time outboundGroups map[string]int - subCache *naSubmissionCache + + // subCache houses the network address submission cache and is protected + // by its own mutex. + subCache *naSubmissionCache } -// ConnectionsWithIP returns the number of connections with the given IP. -func (ps *peerState) ConnectionsWithIP(ip net.IP) int { - var total int - for _, p := range ps.inboundPeers { - if ip.Equal(p.NA().IP) { - total++ - } - } - for _, p := range ps.outboundPeers { - if ip.Equal(p.NA().IP) { - total++ - } - } - for _, p := range ps.persistentPeers { - if ip.Equal(p.NA().IP) { - total++ - } +// makePeerState returns a peer state instance that is used to maintain the +// state of inbound, persistent, and outbound peers as well as banned peers and +// outbound groups. +func makePeerState() peerState { + return peerState{ + inboundPeers: make(map[int32]*serverPeer), + persistentPeers: make(map[int32]*serverPeer), + outboundPeers: make(map[int32]*serverPeer), + banned: make(map[string]time.Time), + outboundGroups: make(map[string]int), + subCache: &naSubmissionCache{ + cache: make(map[string]*naSubmission, maxCachedNaSubmissions), + limit: maxCachedNaSubmissions, + }, } - return total } -// Count returns the count of all known peers. -func (ps *peerState) Count() int { +// count returns the count of all known peers. +// +// This function MUST be called with the embedded mutex locked (for reads). +func (ps *peerState) count() int { return len(ps.inboundPeers) + len(ps.outboundPeers) + len(ps.persistentPeers) } // forAllOutboundPeers is a helper function that runs closure on all outbound // peers known to peerState. +// +// This function MUST be called with the embedded mutex locked (for reads). func (ps *peerState) forAllOutboundPeers(closure func(sp *serverPeer)) { for _, e := range ps.outboundPeers { closure(e) @@ -380,6 +386,8 @@ func (ps *peerState) forAllOutboundPeers(closure func(sp *serverPeer)) { // forAllPeers is a helper function that runs closure on all peers known to // peerState. +// +// This function MUST be called with the embedded mutex locked (for reads). func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) { for _, e := range ps.inboundPeers { closure(e) @@ -387,9 +395,35 @@ func (ps *peerState) forAllPeers(closure func(sp *serverPeer)) { ps.forAllOutboundPeers(closure) } +// ForAllPeers is a helper function that runs closure on all peers known to +// peerState. +// +// This function is safe for concurrent access. +func (ps *peerState) ForAllPeers(closure func(sp *serverPeer)) { + ps.Lock() + ps.forAllPeers(closure) + ps.Unlock() +} + +// connectionsWithIP returns the number of connections with the given IP. +// +// This function MUST be called with the embedded mutex locked (for reads). +func (ps *peerState) connectionsWithIP(ip net.IP) int { + var total int + ps.forAllPeers(func(sp *serverPeer) { + if ip.Equal(sp.NA().IP) { + total++ + } + + }) + return total +} + // ResolveLocalAddress picks the best suggested network address from available // options, per the network interface key provided. The best suggestion, if // found, is added as a local address. +// +// This function is safe for concurrent access. func (ps *peerState) ResolveLocalAddress(netType addrmgr.NetAddressType, addrMgr *addrmgr.AddrManager, services wire.ServiceFlag) { best := ps.subCache.bestSubmission(netType) if best == nil { @@ -518,10 +552,8 @@ type server struct { cpuMiner *cpuminer.CPUMiner mixMsgPool *mixpool.Pool modifyRebroadcastInv chan interface{} - newPeers chan *serverPeer - donePeers chan *serverPeer + peerState peerState banPeers chan *serverPeer - query chan interface{} relayInv chan relayMsg broadcast chan broadcastMsg nat *upnpNAT @@ -552,23 +584,34 @@ type server struct { type serverPeer struct { *peer.Peer - connReq *connmgr.ConnReq - server *server - persistent bool - continueHash atomic.Pointer[chainhash.Hash] - disableRelayTx atomic.Bool - isWhitelisted bool - knownAddresses *apbf.Filter - banScore connmgr.DynamicBanScore - quit chan struct{} + // These fields are set at creation time and never modified afterwards, so + // they do not need to be protected for concurrent access. + server *server + persistent bool + isWhitelisted bool + quit chan struct{} // syncMgrPeer houses the network sync manager peer instance that wraps the // underlying peer similar to the way this server peer itself wraps it. syncMgrPeer *netsync.Peer - // addrsSent, getMiningStateSent and initState all track whether or not - // the peer has already sent the respective request. It is used to - // prevent more than one response per connection. + // All fields below this point are either not set at creation time or are + // otherwise modified during operation and thus need to consider whether or + // not they need to be protected for concurrent access. + + connReq atomic.Pointer[connmgr.ConnReq] + continueHash atomic.Pointer[chainhash.Hash] + disableRelayTx atomic.Bool + knownAddresses *apbf.Filter + banScore connmgr.DynamicBanScore + + // addrsSent, getMiningStateSent and initState track whether or not the peer + // has already sent the respective request. They are used to prevent more + // than one response of each respective request per connection. + // + // They are only accessed directly in callbacks which all run in the same + // peer input handler goroutine and thus do not need to be protected for + // concurrent access. addrsSent bool getMiningStateSent bool initStateSent bool @@ -584,6 +627,9 @@ type serverPeer struct { // announcedBlock tracks the most recent block announced to this peer and is // used to filter duplicates. + // + // It is only accessed in the goroutine that handles relaying inventory and + // thus does not need to be protected for concurrent access. announcedBlock *chainhash.Hash // The following fields are used to serve getdata requests asynchronously as @@ -1934,203 +1980,6 @@ func (s *server) TransactionConfirmed(tx *dcrutil.Tx) { } } -// handleAddPeerMsg deals with adding new peers. It is invoked from the -// peerHandler goroutine. -func (s *server) handleAddPeerMsg(state *peerState, sp *serverPeer) bool { - if sp == nil { - return false - } - - // Ignore new peers if we're shutting down. - if s.shutdown.Load() { - srvrLog.Infof("New peer %s ignored - server is shutting down", sp) - sp.Disconnect() - return false - } - - // Disconnect banned peers. - host, _, err := net.SplitHostPort(sp.Addr()) - if err != nil { - srvrLog.Debugf("can't split hostport %v", err) - sp.Disconnect() - return false - } - if banEnd, ok := state.banned[host]; ok { - if time.Now().Before(banEnd) { - srvrLog.Debugf("Peer %s is banned for another %v - disconnecting", - host, time.Until(banEnd)) - sp.Disconnect() - return false - } - - srvrLog.Infof("Peer %s is no longer banned", host) - delete(state.banned, host) - } - - // Limit max number of connections from a single IP. However, allow - // whitelisted inbound peers and localhost connections regardless. - isInboundWhitelisted := sp.isWhitelisted && sp.Inbound() - peerIP := sp.NA().IP - if cfg.MaxSameIP > 0 && !isInboundWhitelisted && !peerIP.IsLoopback() && - state.ConnectionsWithIP(peerIP)+1 > cfg.MaxSameIP { - srvrLog.Infof("Max connections with %s reached [%d] - "+ - "disconnecting peer", sp, cfg.MaxSameIP) - sp.Disconnect() - return false - } - - // Limit max number of total peers. However, allow whitelisted inbound - // peers regardless. - if state.Count()+1 > cfg.MaxPeers && !isInboundWhitelisted { - srvrLog.Infof("Max peers reached [%d] - disconnecting peer %s", - cfg.MaxPeers, sp) - sp.Disconnect() - // TODO: how to handle permanent peers here? - // they should be rescheduled. - return false - } - - na := sp.peerNa.Load() - - // Add the new peer and start it. - srvrLog.Debugf("New peer %s", sp) - if sp.Inbound() { - state.inboundPeers[sp.ID()] = sp - - if na != nil { - id := na.IP.String() - - // Inbound peers can only corroborate existing address submissions. - if state.subCache.exists(id) { - err := state.subCache.incrementScore(id) - if err != nil { - srvrLog.Errorf("unable to increment submission score: %v", err) - return true - } - } - } - } else { - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]++ - if sp.persistent { - state.persistentPeers[sp.ID()] = sp - } else { - state.outboundPeers[sp.ID()] = sp - } - - // Fetch the suggested public ip from the outbound peer if - // there are no prevailing conditions to disable automatic - // network address discovery. - // - // The conditions to disable automatic network address - // discovery are: - // - If there is a proxy set (--proxy, --onion). - // - If automatic network address discovery is explicitly - // disabled (--nodiscoverip). - // - If there is an external ip explicitly set (--externalip). - // - If listening has been disabled (--nolisten, listen - // disabled because of --connect, etc). - // - If Universal Plug and Play is enabled (--upnp). - // - If the active network is simnet or regnet. - if (cfg.Proxy != "" || cfg.OnionProxy != "") || - cfg.NoDiscoverIP || len(cfg.ExternalIPs) > 0 || - (cfg.DisableListen || len(cfg.Listeners) == 0) || cfg.Upnp || - s.chainParams.Name == simNetParams.Name || - s.chainParams.Name == regNetParams.Name { - return true - } - - if na != nil { - net := addrmgr.IPv4Address - if na.IP.To4() == nil { - net = addrmgr.IPv6Address - } - - localAddr := wireToAddrmgrNetAddress(na) - valid, reach := s.addrManager.ValidatePeerNa(localAddr, remoteAddr) - if !valid { - return true - } - - id := na.IP.String() - if state.subCache.exists(id) { - // Increment the submission score if it already exists. - err := state.subCache.incrementScore(id) - if err != nil { - srvrLog.Errorf("unable to increment submission score: %v", err) - return true - } - } else { - // Create a cache entry for a new submission. - sub := &naSubmission{ - na: na, - netType: net, - reach: reach, - } - - err := state.subCache.add(sub) - if err != nil { - srvrLog.Errorf("unable to add submission: %v", err) - return true - } - } - - // Pick the local address for the provided network based on - // submission scores. - state.ResolveLocalAddress(net, s.addrManager, s.services) - } - } - - return true -} - -// handleDonePeerMsg deals with peers that have signalled they are done. It is -// invoked from the peerHandler goroutine. -func (s *server) handleDonePeerMsg(state *peerState, sp *serverPeer) { - var list map[int32]*serverPeer - if sp.persistent { - list = state.persistentPeers - } else if sp.Inbound() { - list = state.inboundPeers - } else { - list = state.outboundPeers - } - if _, ok := list[sp.ID()]; ok { - if !sp.Inbound() && sp.VersionKnown() { - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - } - if !sp.Inbound() && sp.connReq != nil { - s.connManager.Disconnect(sp.connReq.ID()) - } - delete(list, sp.ID()) - srvrLog.Debugf("Removed peer %s", sp) - return - } - - if sp.connReq != nil { - s.connManager.Disconnect(sp.connReq.ID()) - } - - // Update the address manager with the last seen time when the peer has - // acknowledged our version and has sent us its version as well. This is - // skipped when running on the simulation and regression test networks since - // they are only intended to connect to specified peers and actively avoid - // advertising and connecting to discovered peers. - if !cfg.SimNet && !cfg.RegNet && sp.VerAckReceived() && sp.VersionKnown() && - sp.NA() != nil { - - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - err := s.addrManager.Connected(remoteAddr) - if err != nil { - srvrLog.Errorf("Marking address as connected failed: %v", err) - } - } - - // If we get here it means that either we didn't know about the peer - // or we purposefully deleted it. -} - // handleBanPeerMsg deals with banning peers. It is invoked from the // peerHandler goroutine. func (s *server) handleBanPeerMsg(state *peerState, sp *serverPeer) { @@ -2142,13 +1991,16 @@ func (s *server) handleBanPeerMsg(state *peerState, sp *serverPeer) { direction := directionString(sp.Inbound()) srvrLog.Infof("Banned peer %s (%s) for %v", host, direction, cfg.BanDuration) - state.banned[host] = time.Now().Add(cfg.BanDuration) + bannedUntil := time.Now().Add(cfg.BanDuration) + state.Lock() + state.banned[host] = bannedUntil + state.Unlock() } // handleRelayInvMsg deals with relaying inventory to peers that are not already // known to have it. It is invoked from the peerHandler goroutine. func (s *server) handleRelayInvMsg(state *peerState, msg relayMsg) { - state.forAllPeers(func(sp *serverPeer) { + state.ForAllPeers(func(sp *serverPeer) { if !sp.Connected() { return } @@ -2220,7 +2072,7 @@ func (s *server) handleRelayInvMsg(state *peerState, msg relayMsg) { // handleBroadcastMsg deals with broadcasting messages to peers. It is invoked // from the peerHandler goroutine. func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) { - state.forAllPeers(func(sp *serverPeer) { + state.ForAllPeers(func(sp *serverPeer) { if !sp.Connected() { return } @@ -2235,190 +2087,6 @@ func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) { }) } -type getConnCountMsg struct { - reply chan int32 -} - -type getPeersMsg struct { - reply chan []*serverPeer -} - -type getOutboundGroup struct { - key string - reply chan int -} - -type getAddedNodesMsg struct { - reply chan []*serverPeer -} - -type disconnectNodeMsg struct { - cmp func(*serverPeer) bool - reply chan error -} - -type connectNodeMsg struct { - addr string - permanent bool - reply chan error -} - -type removeNodeMsg struct { - cmp func(*serverPeer) bool - reply chan error -} - -type cancelPendingMsg struct { - addr string - reply chan error -} - -// handleQuery is the central handler for all queries and commands from other -// goroutines related to peer state. -func (s *server) handleQuery(ctx context.Context, state *peerState, querymsg interface{}) { - switch msg := querymsg.(type) { - case getConnCountMsg: - nconnected := int32(0) - state.forAllPeers(func(sp *serverPeer) { - if sp.Connected() { - nconnected++ - } - }) - msg.reply <- nconnected - - case getPeersMsg: - peers := make([]*serverPeer, 0, state.Count()) - state.forAllPeers(func(sp *serverPeer) { - if !sp.Connected() { - return - } - peers = append(peers, sp) - }) - msg.reply <- peers - - case connectNodeMsg: - // XXX duplicate oneshots? - // Limit max number of total peers. - if state.Count() >= cfg.MaxPeers { - msg.reply <- errors.New("max peers reached") - return - } - err := s.connManager.ForEachConnReq(func(c *connmgr.ConnReq) error { - if c.Addr != nil && c.Addr.String() == msg.addr { - if c.Permanent { - return errors.New("peer exists as a permanent peer") - } - - switch c.State() { - case connmgr.ConnPending: - return errors.New("peer pending connection") - case connmgr.ConnEstablished: - return errors.New("peer already connected") - - } - } - return nil - }) - if err != nil { - msg.reply <- err - return - } - - netAddr, err := addrStringToNetAddr(msg.addr) - if err != nil { - msg.reply <- err - return - } - - // TODO: if too many, nuke a non-perm peer. - go s.connManager.Connect(ctx, - &connmgr.ConnReq{ - Addr: netAddr, - Permanent: msg.permanent, - }) - msg.reply <- nil - - case removeNodeMsg: - found := disconnectPeer(state.persistentPeers, msg.cmp, func(sp *serverPeer) { - // Keep group counts ok since we remove from - // the list now. - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - - peerLog.Debugf("Removing persistent peer %s (reqid %d)", remoteAddr, - sp.connReq.ID()) - connReq := sp.connReq - - // Mark the peer's connReq as nil to prevent it from scheduling a - // re-connect attempt. - sp.connReq = nil - s.connManager.Remove(connReq.ID()) - }) - - if found { - msg.reply <- nil - } else { - msg.reply <- errors.New("peer not found") - } - - case cancelPendingMsg: - netAddr, err := addrStringToNetAddr(msg.addr) - if err != nil { - msg.reply <- err - return - } - msg.reply <- s.connManager.CancelPending(netAddr) - - case getOutboundGroup: - count, ok := state.outboundGroups[msg.key] - if ok { - msg.reply <- count - } else { - msg.reply <- 0 - } - - case getAddedNodesMsg: - // Respond with a slice of the relevant peers. - peers := make([]*serverPeer, 0, len(state.persistentPeers)) - for _, sp := range state.persistentPeers { - peers = append(peers, sp) - } - msg.reply <- peers - - case disconnectNodeMsg: - // Check inbound peers. We pass a nil callback since we don't - // require any additional actions on disconnect for inbound peers. - found := disconnectPeer(state.inboundPeers, msg.cmp, nil) - if found { - msg.reply <- nil - return - } - - // Check outbound peers. - found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) { - // Keep group counts ok since we remove from - // the list now. - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - }) - if found { - // If there are multiple outbound connections to the same - // ip:port, continue disconnecting them all until no such - // peers are found. - for found { - found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) { - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - }) - } - msg.reply <- nil - return - } - - msg.reply <- errors.New("peer not found") - } -} - // disconnectPeer attempts to drop the connection of a targeted peer in the // passed peer list. Targets are identified via usage of the passed // `compareFunc`, which should return `true` if the passed peer is the target @@ -2533,15 +2201,16 @@ func (s *server) outboundPeerConnected(c *connmgr.ConnReq, conn net.Conn) { } sp.Peer = p sp.syncMgrPeer = netsync.NewPeer(sp.Peer) - sp.connReq = c + sp.connReq.Store(c) sp.isWhitelisted = isWhitelisted(conn.RemoteAddr()) sp.AssociateConnection(conn) go sp.Run() } -// peerHandler is used to handle peer operations such as adding and removing -// peers to and from the server, banning peers, and broadcasting messages to -// peers. It must be run in a goroutine. +// peerHandler is used to handle peer operations such as banning peers and +// broadcasting messages to peers. +// +// It must be run in a goroutine. func (s *server) peerHandler(ctx context.Context) { // Start the address manager which is needed by peers. This is done here // since its lifecycle is closely tied to this handler and rather than @@ -2551,56 +2220,27 @@ func (s *server) peerHandler(ctx context.Context) { srvrLog.Tracef("Starting peer handler") - state := &peerState{ - inboundPeers: make(map[int32]*serverPeer), - persistentPeers: make(map[int32]*serverPeer), - outboundPeers: make(map[int32]*serverPeer), - banned: make(map[string]time.Time), - outboundGroups: make(map[string]int), - subCache: &naSubmissionCache{ - cache: make(map[string]*naSubmission, maxCachedNaSubmissions), - limit: maxCachedNaSubmissions, - }, - } - out: for { select { - // New peers connected to the server. - case p := <-s.newPeers: - s.handleAddPeerMsg(state, p) - - // Signal the net sync manager this peer is a new sync candidate - // unless it was disconnected above. - if p.Connected() { - s.syncManager.PeerConnected(p.syncMgrPeer) - } - - // Disconnected peers. - case p := <-s.donePeers: - s.handleDonePeerMsg(state, p) - // Peer to ban. case p := <-s.banPeers: - s.handleBanPeerMsg(state, p) + s.handleBanPeerMsg(&s.peerState, p) // New inventory to potentially be relayed to other peers. case invMsg := <-s.relayInv: - s.handleRelayInvMsg(state, invMsg) + s.handleRelayInvMsg(&s.peerState, invMsg) // Message to broadcast to all connected peers except those // which are excluded by the message. case bmsg := <-s.broadcast: - s.handleBroadcastMsg(state, &bmsg) - - case qmsg := <-s.query: - s.handleQuery(ctx, state, qmsg) + s.handleBroadcastMsg(&s.peerState, &bmsg) case <-ctx.Done(): close(s.quit) // Disconnect all peers on server shutdown. - state.forAllPeers(func(sp *serverPeer) { + s.peerState.ForAllPeers(func(sp *serverPeer) { srvrLog.Tracef("Shutdown peer %s", sp) sp.Disconnect() }) @@ -2612,19 +2252,229 @@ out: srvrLog.Tracef("Peer handler done") } +// handleAddPeer deals with adding new peers and includes logic such as +// categorizing the type of peer, limiting the maximum allowed number of peers, +// and local external address resolution. +// +// It returns whether or not the peer was accepted by the server. +// +// This function is safe for concurrent access. +func (s *server) handleAddPeer(sp *serverPeer) bool { + // Ignore new peers when shutting down. + if s.shutdown.Load() { + srvrLog.Infof("New peer %s ignored - server is shutting down", sp) + sp.Disconnect() + return false + } + + state := &s.peerState + defer state.Unlock() + state.Lock() + + // Disconnect banned peers. + host, _, err := net.SplitHostPort(sp.Addr()) + if err != nil { + srvrLog.Debugf("can't split hostport %v", err) + sp.Disconnect() + return false + } + if banEnd, ok := state.banned[host]; ok { + if time.Now().Before(banEnd) { + srvrLog.Debugf("Peer %s is banned for another %v - disconnecting", + host, time.Until(banEnd)) + sp.Disconnect() + return false + } + + srvrLog.Infof("Peer %s is no longer banned", host) + delete(state.banned, host) + } + + // Limit max number of connections from a single IP. However, allow + // whitelisted inbound peers and localhost connections regardless. + isInboundWhitelisted := sp.isWhitelisted && sp.Inbound() + peerIP := sp.NA().IP + if cfg.MaxSameIP > 0 && !isInboundWhitelisted && !peerIP.IsLoopback() && + state.connectionsWithIP(peerIP)+1 > cfg.MaxSameIP { + + srvrLog.Infof("Max connections with %s reached [%d] - disconnecting "+ + "peer", sp, cfg.MaxSameIP) + sp.Disconnect() + return false + } + + // Limit max number of total peers. However, allow whitelisted inbound + // peers regardless. + if state.count()+1 > cfg.MaxPeers && !isInboundWhitelisted { + srvrLog.Infof("Max peers reached [%d] - disconnecting peer %s", + cfg.MaxPeers, sp) + sp.Disconnect() + // TODO: how to handle permanent peers here? + // they should be rescheduled. + return false + } + + na := sp.peerNa.Load() + + // Add the new peer. + srvrLog.Debugf("New peer %s", sp) + if sp.Inbound() { + state.inboundPeers[sp.ID()] = sp + + if na != nil { + id := na.IP.String() + + // Inbound peers can only corroborate existing address submissions. + if state.subCache.exists(id) { + err := state.subCache.incrementScore(id) + if err != nil { + srvrLog.Errorf("unable to increment submission score: %v", err) + return true + } + } + } + + return true + } + + // The peer is an outbound peer at this point. + remoteAddr := wireToAddrmgrNetAddress(sp.NA()) + state.outboundGroups[remoteAddr.GroupKey()]++ + if sp.persistent { + state.persistentPeers[sp.ID()] = sp + } else { + state.outboundPeers[sp.ID()] = sp + } + + // Fetch the suggested public IP from the outbound peer if there are no + // prevailing conditions to disable automatic network address discovery. + // + // The conditions to disable automatic network address discovery are: + // - There is a proxy set (--proxy, --onion) + // - Automatic network address discovery is explicitly disabled + // (--nodiscoverip) + // - There is an external IP explicitly set (--externalip) + // - Listening has been disabled (--nolisten, listen disabled because of + // --connect, etc) + // - Universal Plug and Play is enabled (--upnp) + // - The active network is simnet or regnet + if (cfg.Proxy != "" || cfg.OnionProxy != "") || + cfg.NoDiscoverIP || + len(cfg.ExternalIPs) > 0 || + (cfg.DisableListen || len(cfg.Listeners) == 0) || cfg.Upnp || + s.chainParams.Name == simNetParams.Name || + s.chainParams.Name == regNetParams.Name { + + return true + } + + if na != nil { + net := addrmgr.IPv4Address + if na.IP.To4() == nil { + net = addrmgr.IPv6Address + } + + localAddr := wireToAddrmgrNetAddress(na) + valid, reach := s.addrManager.ValidatePeerNa(localAddr, remoteAddr) + if !valid { + return true + } + + id := na.IP.String() + if state.subCache.exists(id) { + // Increment the submission score if it already exists. + err := state.subCache.incrementScore(id) + if err != nil { + srvrLog.Errorf("unable to increment submission score: %v", err) + return true + } + } else { + // Create a cache entry for a new submission. + sub := &naSubmission{ + na: na, + netType: net, + reach: reach, + } + + err := state.subCache.add(sub) + if err != nil { + srvrLog.Errorf("unable to add submission: %v", err) + return true + } + } + + // Pick the local address for the provided network based on + // submission scores. + state.ResolveLocalAddress(net, s.addrManager, s.services) + } + + return true +} + // AddPeer adds a new peer that has already been connected to the server. +// +// This function is safe for concurrent access. func (s *server) AddPeer(sp *serverPeer) { - select { - case <-s.quit: - case s.newPeers <- sp: + s.handleAddPeer(sp) + + // Signal the net sync manager this peer is a new sync candidate unless it + // was disconnected above. + if sp.Connected() { + s.syncManager.PeerConnected(sp.syncMgrPeer) } } -// DonePeer removes a disconnected peer from the server. +// DonePeer removes a disconnected peer from the server. It includes logic such +// as updating the peer tracking state and the last time the peer was seen. +// +// This function is safe for concurrent access. func (s *server) DonePeer(sp *serverPeer) { - select { - case <-s.quit: - case s.donePeers <- sp: + state := &s.peerState + defer state.Unlock() + state.Lock() + + var list map[int32]*serverPeer + if sp.persistent { + list = state.persistentPeers + } else if sp.Inbound() { + list = state.inboundPeers + } else { + list = state.outboundPeers + } + if _, ok := list[sp.ID()]; ok { + if !sp.Inbound() && sp.VersionKnown() { + remoteAddr := wireToAddrmgrNetAddress(sp.NA()) + state.outboundGroups[remoteAddr.GroupKey()]-- + } + if !sp.Inbound() { + connReq := sp.connReq.Load() + if connReq != nil { + s.connManager.Disconnect(connReq.ID()) + } + } + delete(list, sp.ID()) + srvrLog.Debugf("Removed peer %s", sp) + return + } + + connReq := sp.connReq.Load() + if connReq != nil { + s.connManager.Disconnect(connReq.ID()) + } + + // Update the address manager with the last seen time when the peer has + // acknowledged our version and has sent us its version as well. This is + // skipped when running on the simulation and regression test networks since + // they are only intended to connect to specified peers and actively avoid + // advertising and connecting to discovered peers. + if !cfg.SimNet && !cfg.RegNet && sp.VerAckReceived() && sp.VersionKnown() && + sp.NA() != nil { + + remoteAddr := wireToAddrmgrNetAddress(sp.NA()) + err := s.addrManager.Connected(remoteAddr) + if err != nil { + srvrLog.Errorf("Marking address as connected failed: %v", err) + } } } @@ -2678,25 +2528,22 @@ func (s *server) BroadcastMessage(msg wire.Message, exclPeers ...*serverPeer) { // ConnectedCount returns the number of currently connected peers. func (s *server) ConnectedCount() int32 { - replyChan := make(chan int32) - select { - case <-s.quit: - return 0 - case s.query <- getConnCountMsg{reply: replyChan}: - return <-replyChan - } + var numConnected int32 + s.peerState.ForAllPeers(func(sp *serverPeer) { + if sp.Connected() { + numConnected++ + } + }) + return numConnected } // OutboundGroupCount returns the number of peers connected to the given // outbound group key. func (s *server) OutboundGroupCount(key string) int { - replyChan := make(chan int) - select { - case <-s.quit: - return 0 - case s.query <- getOutboundGroup{key: key, reply: replyChan}: - return <-replyChan - } + s.peerState.Lock() + count := s.peerState.outboundGroups[key] + s.peerState.Unlock() + return count } // AddBytesSent adds the passed number of bytes to the total bytes sent counter @@ -3864,10 +3711,8 @@ func newServer(ctx context.Context, listenAddrs []string, db database.DB, s := server{ chainParams: chainParams, addrManager: amgr, - newPeers: make(chan *serverPeer, cfg.MaxPeers), - donePeers: make(chan *serverPeer, cfg.MaxPeers), + peerState: makePeerState(), banPeers: make(chan *serverPeer, cfg.MaxPeers), - query: make(chan interface{}), relayInv: make(chan relayMsg, cfg.MaxPeers), broadcast: make(chan broadcastMsg, cfg.MaxPeers), modifyRebroadcastInv: make(chan interface{}),