diff --git a/p2p/node/api.go b/p2p/node/api.go index 0aae8d190a..fd82678926 100644 --- a/p2p/node/api.go +++ b/p2p/node/api.go @@ -331,17 +331,24 @@ func (p *P2PNode) GetTrieNode(hash common.Hash, location common.Location) *trie. } func (p *P2PNode) handleBroadcast(sourcePeer peer.ID, topic string, data interface{}, nodeLocation common.Location) { - switch data.(type) { - case types.WorkObjectBlockView: - case types.WorkObjectHeaderView: - case types.Transactions: - case types.WorkObjectHeader: - default: - log.Global.Debugf("received unsupported block broadcast") - // TODO: ban the peer which sent it? + if _, ok := acceptableTypes[reflect.TypeOf(data)]; !ok { + log.Global.WithFields(log.Fields{ + "peer": sourcePeer, + "topic": topic, + "type": reflect.TypeOf(data), + }).Warn("Received unsupported broadcast") return } + switch v := data.(type) { + case types.WorkObjectHeader: + p.cacheAdd(v.Hash(), &v, nodeLocation) + case types.WorkObjectHeaderView: + p.cacheAdd(v.Hash(), &v, nodeLocation) + case types.WorkObjectBlockView: + p.cacheAdd(v.Hash(), &v, nodeLocation) + } + // If we made it here, pass the data on to the consensus backend if p.consensus != nil { p.consensus.OnNewBroadcast(sourcePeer, topic, data, nodeLocation) diff --git a/p2p/node/node.go b/p2p/node/node.go index 7a9246cee4..49f1e6ea0e 100644 --- a/p2p/node/node.go +++ b/p2p/node/node.go @@ -55,7 +55,7 @@ type P2PNode struct { requestManager requestManager.RequestManager // Caches for each type of data we may receive - cache map[string]map[string]*lru.Cache[common.Hash, interface{}] + cache map[string]map[reflect.Type]*lru.Cache[common.Hash, interface{}] // Channel to signal when to quit and shutdown quitCh chan struct{} @@ -246,13 +246,20 @@ func (p *P2PNode) Close() error { return nil } -func initializeCaches(locations []common.Location) map[string]map[string]*lru.Cache[common.Hash, interface{}] { - caches := make(map[string]map[string]*lru.Cache[common.Hash, interface{}]) +// acceptableTypes is used to filter out unsupported broadcast types +var acceptableTypes = map[reflect.Type]struct{}{ + reflect.TypeOf(types.WorkObjectHeader{}): {}, + reflect.TypeOf(types.WorkObjectBlockView{}): {}, + reflect.TypeOf(types.WorkObjectHeaderView{}): {}, + reflect.TypeOf(types.Transactions{}): {}, +} + +func initializeCaches(locations []common.Location) map[string]map[reflect.Type]*lru.Cache[common.Hash, interface{}] { + caches := make(map[string]map[reflect.Type]*lru.Cache[common.Hash, interface{}]) for _, location := range locations { - locCache := map[string]*lru.Cache[common.Hash, interface{}]{ - "blocks": createCache(c_defaultCacheSize), - "transactions": createCache(c_defaultCacheSize), - "headers": createCache(c_defaultCacheSize), + locCache := map[reflect.Type]*lru.Cache[common.Hash, interface{}]{} + for typ := range acceptableTypes { + locCache[reflect.PointerTo(typ)] = createCache(c_defaultCacheSize) } caches[location.Name()] = locCache } @@ -260,7 +267,7 @@ func initializeCaches(locations []common.Location) map[string]map[string]*lru.Ca } func createCache(size int) *lru.Cache[common.Hash, interface{}] { - cache, err := lru.New[common.Hash, interface{}](size) // Assuming a fixed size of 10 for each cache + cache, err := lru.New[common.Hash, interface{}](size) if err != nil { log.Global.Fatal("error initializing cache;", err) } @@ -274,22 +281,11 @@ func (p *P2PNode) p2pAddress() (multiaddr.Multiaddr, error) { // Helper to access the corresponding data cache func (p *P2PNode) pickCache(datatype interface{}, location common.Location) *lru.Cache[common.Hash, interface{}] { - switch datatype.(type) { - case *types.WorkObject, *types.WorkObjectHeaderView, *types.WorkObjectBlockView: - return p.cache[location.Name()]["blocks"] - case *types.Transaction: - return p.cache[location.Name()]["transactions"] - case *types.Header: - return p.cache[location.Name()]["headers"] - default: - log.Global.WithField("type", reflect.TypeOf(datatype)).Fatalf("unsupported type") - return nil - } + return p.cache[location.Name()][reflect.TypeOf(datatype)] } // Add a datagram into the corresponding cache func (p *P2PNode) cacheAdd(hash common.Hash, data interface{}, location common.Location) { - return cache := p.pickCache(data, location) cache.Add(hash, data) }