Skip to content
This repository has been archived by the owner on Jul 21, 2021. It is now read-only.

Commit

Permalink
Update DNS lookup behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu Xie committed Mar 23, 2018
1 parent c4fab1a commit fb52b7d
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 22 deletions.
92 changes: 70 additions & 22 deletions zk/dnshostprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,107 @@ import (
"fmt"
"net"
"sync"
"time"
)

// lookupInterval is the interval of retrying DNS lookup for unresolved hosts
const lookupInterval = time.Minute * 3

// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
type DNSHostProvider struct {
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string
curr int
last int
lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing.
// fields above mu are not thread safe
unresolvedServers map[string]struct{}
sleep func(time.Duration) // Override of time.Sleep, for testing.
lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing.

mu sync.Mutex // Protects everything below, so we can add asynchronous updates later.
servers []string
curr int
last int
}

// Init is called first, with the servers specified in the connection
// string. It uses DNS to look up addresses for each server, then
// shuffles them all together.
func (hp *DNSHostProvider) Init(servers []string) error {
hp.mu.Lock()
defer hp.mu.Unlock()

lookupHost := hp.lookupHost
if lookupHost == nil {
lookupHost = net.LookupHost
if hp.sleep == nil {
hp.sleep = time.Sleep
}
if hp.lookupHost == nil {
hp.lookupHost = net.LookupHost
}

found := []string{}
hp.servers = make([]string, 0, len(servers))
hp.unresolvedServers = make(map[string]struct{}, len(servers))
for _, server := range servers {
hp.unresolvedServers[server] = struct{}{}
}

done, err := hp.lookupUnresolvedServers()
if err != nil {
return err
}

// as long as any host resolved successfully, consider the connection as success
// but start a lookup loop until all servers are resolved and added to servers list
if !done {
go hp.lookupLoop()
}

return nil
}

// lookupLoop calls lookupUnresolvedServers in an infinite loop until all hosts are resolved
// should be called in a separate goroutine
func (hp *DNSHostProvider) lookupLoop() {
for {
if done, _ := hp.lookupUnresolvedServers(); done {
break
}
hp.sleep(lookupInterval)
}
}

// lookupUnresolvedServers DNS lookup the hosts that not successfully resolved yet
// and add them to servers list
func (hp *DNSHostProvider) lookupUnresolvedServers() (bool, error) {
if len(hp.unresolvedServers) == 0 {
return true, nil
}

found := make([]string, 0, len(hp.unresolvedServers))
for server := range hp.unresolvedServers {
host, port, err := net.SplitHostPort(server)
if err != nil {
return err
return false, err
}
addrs, err := lookupHost(host)
addrs, err := hp.lookupHost(host)
if err != nil {
return err
continue
}
delete(hp.unresolvedServers, server)
for _, addr := range addrs {
found = append(found, net.JoinHostPort(addr, port))
}
}

if len(found) == 0 {
return fmt.Errorf("No hosts found for addresses %q", servers)
}

// Randomize the order of the servers to avoid creating hotspots
stringShuffle(found)

hp.servers = found
hp.mu.Lock()
defer hp.mu.Unlock()

hp.servers = append(hp.servers, found...)
hp.curr = -1
hp.last = -1

return nil
if len(hp.servers) == 0 {
return true, fmt.Errorf("No hosts found for addresses %q", hp.servers)
}

return false, nil
}

// Len returns the number of servers available
Expand Down
45 changes: 45 additions & 0 deletions zk/dnshostprovider_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zk

import (
"errors"
"fmt"
"log"
"testing"
Expand Down Expand Up @@ -165,6 +166,50 @@ func TestDNSHostProviderReconnect(t *testing.T) {
}
}

// TestDNSHostOneHostDead tests whether
func TestDNSHostOneHostDead(t *testing.T) {
// use channel to simulate a server that was initially dead but came back online later
ch := make(chan struct{}, 0)
hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
if host != "foo.failure.com" {
return []string{"192.0.2.1", "192.0.2.2"}, nil
}
select {
case <-ch:
return []string{"192.0.2.3"}, nil
default:
return nil, errors.New("Fails to ns lookup")
}
}, sleep: func(_ time.Duration) {}}

if err := hp.Init([]string{"foo.failure.com:12345", "foo.success.com:12345"}); err != nil {
t.Fatal(err)
}

hp.mu.Lock()
if len(hp.servers) != 2 {
t.Fatal("Only servers that resolved by lookupHost should be in servers list")
}
hp.mu.Unlock()

// simulating one server comes back online
close(ch)

// starts a 30s retry loop to wait servers list to be updated
startRetryLoop := time.Now()
for {
time.Sleep(time.Millisecond * 5)
hp.mu.Lock()
if len(hp.servers) == 3 {
break
}
hp.mu.Unlock()
if time.Since(startRetryLoop) > time.Second*30 {
t.Fatal("Servers get back online should be added to the servers list")
}
}
}

// TestDNSHostProviderRetryStart tests the `retryStart` functionality
// of DNSHostProvider.
// It's also probably the clearest visual explanation of exactly how
Expand Down

0 comments on commit fb52b7d

Please sign in to comment.