Skip to content

Commit

Permalink
Merge pull request #1 from justinruggles/client-trace
Browse files Browse the repository at this point in the history
Add httptrace.ClientTrace handling to custom Resolver
  • Loading branch information
cevatbarisyilmaz authored Sep 7, 2020
2 parents 1ac4674 + 8c59eb2 commit 045c3b1
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
28 changes: 28 additions & 0 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ara
import (
"context"
"net"
"net/http/httptrace"
)

// A Resolver looks up hosts.
Expand All @@ -27,9 +28,36 @@ func NewCustomResolver(hosts map[string][]string) Resolver {
}
}

func handleClientTrace(t *httptrace.ClientTrace, host string, records []string) {
if t.DNSStart != nil {
t.DNSStart(httptrace.DNSStartInfo{
Host: host,
})
}
if t.DNSDone != nil {
var addrs []net.IPAddr
for _, rec := range records {
ip := net.ParseIP(rec)
if ip != nil {
addrs = append(addrs, net.IPAddr{
IP: ip,
})
}
}
t.DNSDone(httptrace.DNSDoneInfo{
Addrs: addrs,
})
}
}

func (r *resolver) LookupHost(ctx context.Context, host string) ([]string, error) {
records := r.hosts[host]
if records != nil && len(records) != 0 {
t := httptrace.ContextClientTrace(ctx)
if t != nil {
handleClientTrace(t, host, records)
}

return records, nil
}
return net.DefaultResolver.LookupHost(ctx, host)
Expand Down
37 changes: 36 additions & 1 deletion resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package ara_test

import (
"context"
"github.com/cevatbarisyilmaz/ara"
"net/http/httptrace"
"testing"

"github.com/cevatbarisyilmaz/ara"
)

func TestNewCustomResolver(t *testing.T) {
Expand All @@ -29,3 +31,36 @@ func TestNewCustomResolver(t *testing.T) {
t.Error("no addresses")
}
}

func TestClientTrace(t *testing.T) {
resolver := ara.NewCustomResolver(map[string][]string{"example.com": {"127.0.0.1"}})

var gotDNSStart, gotDNSDone bool
ctx := context.Background()
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
DNSStart: func(info httptrace.DNSStartInfo) {
if info.Host != "example.com" {
t.Error("wrong address in ClientTrace.DNSStart")
}
gotDNSStart = true
},
DNSDone: func(info httptrace.DNSDoneInfo) {
if info.Err != nil {
t.Error("non-nil error in ClientTrace.DNSDone")
}
if len(info.Addrs) != 1 || info.Addrs[0].String() != "127.0.0.1" {
t.Error("wrong IP in ClientTrace.DNSDone")
}
gotDNSDone = true
},
})

resolver.LookupHost(ctx, "example.com")

if !gotDNSStart {
t.Error("ClientTrace.DNSStart not called")
}
if !gotDNSDone {
t.Error("ClientTrace.DNSDone not called")
}
}

0 comments on commit 045c3b1

Please sign in to comment.