diff --git a/mxresolv/mxresolv.go b/mxresolv/mxresolv.go index 77bc3fa..da9b3f0 100644 --- a/mxresolv/mxresolv.go +++ b/mxresolv/mxresolv.go @@ -39,54 +39,78 @@ func init() { lookupResultCache = collections.NewLRUCache(cacheSize) } -// Lookup performs a DNS lookup of MX records for the specified hostname. It +// Lookup performs a DNS lookup of MX records for the specified domain. It // returns a prioritised list of MX hostnames, where hostnames with the same // priority are shuffled. If the second returned value is true, then the host // does not have explicit MX records, and its A record is returned instead. // // It uses an LRU cache with a timeout to reduce the number of network requests. -func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImplicit bool, reterr error) { - if cachedVal, ok := lookupResultCache.Get(hostname); ok { +func Lookup(ctx context.Context, domain string) (mxHosts []string, implicit bool, err error) { + mxRecords, implicit, err := LookupWithPref(ctx, domain) + if err != nil { + return nil, false, err + } + if len(mxRecords) == 1 { + return []string{mxRecords[0].Host}, implicit, err + } + return shuffleMXRecords(mxRecords), false, nil +} + +// LookupWithPref performs a DNS lookup of MX records for the specified domain. +// It returns a slice of net.MX records that are ordered by preference. Records +// with the same preference are sorted by hostname to ensure deterministic +// behaviour. If the second returned value is true, then the host does not have +// explicit MX records, and its A record is used instead. +// +// It uses an LRU cache with a timeout to reduce the number of network requests. +func LookupWithPref(ctx context.Context, domainName string) (mxRecords []*net.MX, implicit bool, err error) { + if cachedVal, ok := lookupResultCache.Get(domainName); ok { cachedLookupResult := cachedVal.(lookupResult) - if cachedLookupResult.shuffled { - reshuffledMXHosts, _ := shuffleMXRecords(cachedLookupResult.mxRecords) - return reshuffledMXHosts, cachedLookupResult.implicit, cachedLookupResult.err - } - return cachedLookupResult.mxHosts, cachedLookupResult.implicit, cachedLookupResult.err + return cachedLookupResult.mxRecords, cachedLookupResult.implicit, cachedLookupResult.err } - asciiHostname, err := ensureASCII(hostname) + asciiDomainName, err := ensureASCII(domainName) if err != nil { - return nil, false, errors.Wrap(err, "invalid hostname") + return nil, false, errors.Wrap(err, "invalid domain name") } - mxRecords, err := lookupMX(Resolver, ctx, asciiHostname) + mxRecords, err = lookupMX(Resolver, ctx, asciiDomainName) if err != nil { - var timeouter interface{ Timeout() bool } - if errors.As(err, &timeouter) && timeouter.Timeout() { - return nil, false, errors.WithStack(err) - } var netDNSError *net.DNSError if errors.As(err, &netDNSError) && netDNSError.IsNotFound { - if _, err := Resolver.LookupIPAddr(ctx, asciiHostname); err != nil { - return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err)) + if _, err := Resolver.LookupIPAddr(ctx, asciiDomainName); err != nil { + return cacheAndReturn(domainName, nil, false, errors.WithStack(err)) } - return cacheAndReturn(hostname, []string{asciiHostname}, nil, false, true, nil) + return cacheAndReturn(domainName, []*net.MX{{Host: asciiDomainName, Pref: 1}}, true, nil) } if mxRecords == nil { - return cacheAndReturn(hostname, nil, nil, false, false, errors.WithStack(err)) + return cacheAndReturn(domainName, nil, false, errors.WithStack(err)) } } // Check for "Null MX" record (https://tools.ietf.org/html/rfc7505). if len(mxRecords) == 1 { if mxRecords[0].Host == "." { - return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord) + return cacheAndReturn(domainName, nil, false, errNullMXRecord) } // 0.0.0.0 is not really a "Null MX" record, but some people apparently // have never heard of RFC7505 and configure it this way. if strings.HasPrefix(mxRecords[0].Host, "0.0.0.0") { - return cacheAndReturn(hostname, nil, nil, false, false, errNullMXRecord) + return cacheAndReturn(domainName, nil, false, errNullMXRecord) } } + // Purge records with non-ASCII characters. we have seen such records in + // production, they are obviously products of human errors. + for i := 0; i < len(mxRecords); { + if isASCII(mxRecords[i].Host) { + i++ + continue + } + copy(mxRecords[i:], mxRecords[i+1:]) + mxRecords = mxRecords[:len(mxRecords)-1] + } + // If there are no valid records left, then return an error. + if len(mxRecords) == 0 { + return cacheAndReturn(domainName, nil, false, errNoValidMXHosts) + } // Normalize returned hostnames: drop trailing '.' and lowercase. for _, mxRecord := range mxRecords { lastCharIndex := len(mxRecord.Host) - 1 @@ -100,11 +124,7 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli return mxRecords[i].Pref < mxRecords[j].Pref || (mxRecords[i].Pref == mxRecords[j].Pref && mxRecords[i].Host < mxRecords[j].Host) }) - mxHosts, shuffled := shuffleMXRecords(mxRecords) - if len(mxHosts) == 0 { - return cacheAndReturn(hostname, nil, nil, false, false, errNoValidMXHosts) - } - return cacheAndReturn(hostname, mxHosts, mxRecords, shuffled, false, nil) + return cacheAndReturn(domainName, mxRecords, false, nil) } // SetDeterministicInTests sets rand to deterministic seed for testing, and is @@ -126,14 +146,13 @@ func ResetCache() { lookupResultCache = collections.NewLRUCache(1000) } -func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) { +func shuffleMXRecords(mxRecords []*net.MX) []string { // Shuffle the hosts within the preference groups. var ( mxHosts []string groupBegin = 0 groupEnd = 0 groupPref uint16 - shuffled = false ) for _, mxRecord := range mxRecords { // If a hostname has non-ASCII characters then ignore it, for it is @@ -165,7 +184,6 @@ func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) { // After finding the end of the current preference group, shuffle it. if groupEnd-groupBegin > 1 { shuffleHosts(mxHosts[groupBegin:groupEnd]) - shuffled = true } // Set up the next preference group. groupBegin = groupEnd @@ -175,9 +193,8 @@ func shuffleMXRecords(mxRecords []*net.MX) ([]string, bool) { // Shuffle the last preference group, if there is one. if groupEnd-groupBegin > 1 { shuffleHosts(mxHosts[groupBegin:groupEnd]) - shuffled = true } - return mxHosts, shuffled + return mxHosts } func shuffleHosts(hosts []string) { @@ -208,15 +225,13 @@ func isASCII(s string) bool { type lookupResult struct { mxRecords []*net.MX - mxHosts []string - shuffled bool implicit bool err error } -func cacheAndReturn(hostname string, mxHosts []string, mxRecords []*net.MX, shuffled, implicit bool, err error) (retMxHosts []string, retImplicit bool, reterr error) { - lookupResultCache.AddWithTTL(hostname, lookupResult{mxHosts: mxHosts, mxRecords: mxRecords, shuffled: shuffled, implicit: implicit, err: err}, cacheTTL) - return mxHosts, implicit, err +func cacheAndReturn(hostname string, mxRecords []*net.MX, implicit bool, err error) ([]*net.MX, bool, error) { + lookupResultCache.AddWithTTL(hostname, lookupResult{mxRecords: mxRecords, implicit: implicit, err: err}, cacheTTL) + return mxRecords, implicit, err } // lookupMX exposes the respective private function of net.Resolver. The public diff --git a/mxresolv/mxresolv_test.go b/mxresolv/mxresolv_test.go index 8e95de3..68a4c95 100644 --- a/mxresolv/mxresolv_test.go +++ b/mxresolv/mxresolv_test.go @@ -111,6 +111,76 @@ func TestMain(m *testing.M) { os.Exit(exitVal) } +func TestLookupWithPref(t *testing.T) { + for _, tc := range []struct { + desc string + inDomainName string + outMXHosts []*net.MX + outImplicitMX bool + }{{ + desc: "MX record preference is respected", + inDomainName: "test-mx.definbox.com", + outMXHosts: []*net.MX{ + {Host: "mxa.definbox.com", Pref: 1}, {Host: "mxe.definbox.com", Pref: 1}, {Host: "mxi.definbox.com", Pref: 1}, + {Host: "mxc.definbox.com", Pref: 2}, + {Host: "mxb.definbox.com", Pref: 3}, {Host: "mxd.definbox.com", Pref: 3}, {Host: "mxf.definbox.com", Pref: 3}, {Host: "mxg.definbox.com", Pref: 3}, {Host: "mxh.definbox.com", Pref: 3}, + }, + outImplicitMX: false, + }, { + inDomainName: "test-a.definbox.com", + outMXHosts: []*net.MX{{Host: "test-a.definbox.com", Pref: 1}}, + outImplicitMX: true, + }, { + inDomainName: "test-cname.definbox.com", + outMXHosts: []*net.MX{{Host: "mxa.ninomail.com", Pref: 10}, {Host: "mxb.ninomail.com", Pref: 10}}, + outImplicitMX: false, + }, { + inDomainName: "definbox.com", + outMXHosts: []*net.MX{{Host: "mxa.ninomail.com", Pref: 10}, {Host: "mxb.ninomail.com", Pref: 10}}, + outImplicitMX: false, + }, { + desc: "If an MX host returned by the resolver contains non ASCII " + + "characters then it is silently dropped from the returned list", + inDomainName: "test-unicode.definbox.com", + outMXHosts: []*net.MX{{Host: "mxa.definbox.com", Pref: 1}, {Host: "mxb.definbox.com", Pref: 3}}, + outImplicitMX: false, + }, { + desc: "Underscore is allowed in domain names", + inDomainName: "test-underscore.definbox.com", + outMXHosts: []*net.MX{{Host: "foo_bar.definbox.com", Pref: 1}}, + outImplicitMX: false, + }, { + inDomainName: "test-яндекс.definbox.com", + outMXHosts: []*net.MX{{Host: "xn--test---mofb0ab4b8camvcmn8gxd.definbox.com", Pref: 10}}, + outImplicitMX: false, + }, { + inDomainName: "xn--test--xweh4bya7b6j.definbox.com", + outMXHosts: []*net.MX{{Host: "xn--test---mofb0ab4b8camvcmn8gxd.definbox.com", Pref: 10}}, + outImplicitMX: false, + }, { + inDomainName: "test-mx-ipv4.definbox.com", + outMXHosts: []*net.MX{{Host: "34.150.176.225", Pref: 10}}, + outImplicitMX: false, + }, { + inDomainName: "test-mx-ipv6.definbox.com", + outMXHosts: []*net.MX{{Host: "::ffff:2296:b0e1", Pref: 10}}, + outImplicitMX: false, + }} { + t.Run(tc.inDomainName, func(t *testing.T) { + defer mxresolv.SetDeterministicInTests()() + + // When + ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) + defer cancel() + mxRecords, implicitMX, err := mxresolv.LookupWithPref(ctx, tc.inDomainName) + // Then + assert.NoError(t, err) + assert.Equal(t, tc.outMXHosts, mxRecords) + assert.Equal(t, tc.outImplicitMX, implicitMX) + }) + } +} + func TestLookup(t *testing.T) { for _, tc := range []struct { desc string @@ -172,11 +242,11 @@ func TestLookup(t *testing.T) { // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) defer cancel() - mxHosts, explictMX, err := mxresolv.Lookup(ctx, tc.inDomainName) + mxHosts, implicitMX, err := mxresolv.Lookup(ctx, tc.inDomainName) // Then assert.NoError(t, err) assert.Equal(t, tc.outMXHosts, mxHosts) - assert.Equal(t, tc.outImplicitMX, explictMX) + assert.Equal(t, tc.outImplicitMX, implicitMX) }) } }