diff --git a/check/dns-record/check.go b/check/dns-record/check.go index 23e7059..7910ee5 100644 --- a/check/dns-record/check.go +++ b/check/dns-record/check.go @@ -52,6 +52,10 @@ func New(_ check.Database, configure func(interface{}) error) (check.Check, erro check.lookup = lookupAAAA case "CNAME": check.lookup = lookupCNAME + case "MX": + check.lookup = lookupMX + case "NS": + check.lookup = lookupNS case "TXT": check.lookup = lookupTXT default: @@ -89,21 +93,25 @@ func resolver(addr string) *net.Resolver { } func lookupA(r *net.Resolver, name string) ([]string, error) { - records, err := r.LookupIP(context.Background(), "ip4", name) - - if err != nil { - return nil, err - } - return mapToStrings(records), nil + return lookupIP(r, "ip4", name) } func lookupAAAA(r *net.Resolver, name string) ([]string, error) { - records, err := r.LookupIP(context.Background(), "ip6", name) + return lookupIP(r, "ip6", name) +} + +func lookupIP(r *net.Resolver, network, name string) ([]string, error) { + ips, err := r.LookupIP(context.Background(), network, name) if err != nil { return nil, err } - return mapToStrings(records), nil + records := make([]string, len(ips)) + + for i, ip := range ips { + records[i] = ip.String() + } + return records, nil } func lookupCNAME(r *net.Resolver, name string) ([]string, error) { @@ -115,15 +123,34 @@ func lookupCNAME(r *net.Resolver, name string) ([]string, error) { return []string{record}, nil } -func lookupTXT(r *net.Resolver, name string) ([]string, error) { - return r.LookupTXT(context.Background(), name) +func lookupMX(r *net.Resolver, name string) ([]string, error) { + mxs, err := r.LookupMX(context.Background(), name) + + if err != nil { + return nil, err + } + records := make([]string, len(mxs)) + + for i, mx := range mxs { + records[i] = mx.Host + } + return records, nil } -func mapToStrings[T fmt.Stringer](values []T) []string { - s := make([]string, len(values)) +func lookupNS(r *net.Resolver, name string) ([]string, error) { + nss, err := r.LookupNS(context.Background(), name) + + if err != nil { + return nil, err + } + records := make([]string, len(nss)) - for i, val := range values { - s[i] = val.String() + for i, ns := range nss { + records[i] = ns.Host } - return s + return records, nil +} + +func lookupTXT(r *net.Resolver, name string) ([]string, error) { + return r.LookupTXT(context.Background(), name) }