diff --git a/blocklist.go b/blocklist.go index eac58ad6..f127ff42 100644 --- a/blocklist.go +++ b/blocklist.go @@ -2,6 +2,7 @@ package zgrab2 import ( "bufio" + "fmt" "io" "log" "net" @@ -45,6 +46,32 @@ func (t *RadixTree) Insert(cidr *net.IPNet) { node.isLeaf = true } +func (t *RadixTree) InsertEntry(entry string) error { + // Parse the entry + ip, cidr, err := net.ParseCIDR(entry) + if err == nil { + // Entry is a CIDR range + t.Insert(cidr) + return nil + } + + // If not CIDR, treat it as a single IP + ip = net.ParseIP(entry) + if ip == nil { + return fmt.Errorf("invalid IP or CIDR: %s", entry) + } + + // Create a /32 or /128 mask for the single IP + bits := net.IPv4len * 8 // Default to IPv4 + if ip.To4() == nil { + bits = net.IPv6len * 8 + } + mask := net.CIDRMask(bits, bits) + cidr = &net.IPNet{IP: ip, Mask: mask} + t.Insert(cidr) + return nil +} + // Contains checks if an IP is covered by any CIDR range in the tree func (t *RadixTree) Contains(ip net.IP) bool { node := t.root @@ -105,12 +132,10 @@ func LoadBlocklist(source io.Reader) (*RadixTree, error) { continue } - _, cidr, err := net.ParseCIDR(txt) - if err != nil { - log.Printf("Invalid CIDR in blocklist: %s", scanner.Text()) + if err := tree.InsertEntry(txt); err != nil { + log.Printf("Invalid blocklist entry %s: %w", txt, err) continue } - tree.Insert(cidr) } if scanner.Err() != nil { return nil, scanner.Err()