From 420aadb4dafb13e20c89649dc421175301ec4e50 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 25 Oct 2023 15:38:55 +0200 Subject: [PATCH] feat(dslx): implement DNSLookupParallel (#1387) Closes https://github.com/ooni/probe/issues/2617 While there: there's no need to return the Trace used for DNS lookup and there's no meaning in aggregating traces when running DNSLookupParallel, because we have multiple traces, so let's just drop ResolvedAddresses.Trace. --- internal/dslx/address.go | 8 +++ internal/dslx/dns.go | 44 ++++++++++-- internal/dslx/dns_test.go | 6 -- internal/dslx/qa_test.go | 136 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 13 deletions(-) create mode 100644 internal/dslx/qa_test.go diff --git a/internal/dslx/address.go b/internal/dslx/address.go index fe12c2c73f..c0b3076cdc 100644 --- a/internal/dslx/address.go +++ b/internal/dslx/address.go @@ -57,6 +57,14 @@ func (as *AddressSet) RemoveBogons() *AddressSet { return as } +// Uniq returns the unique addresses. +func (as *AddressSet) Uniq() (uniq []string) { + for addr := range as.M { + uniq = append(uniq, addr) + } + return +} + // EndpointPort is the port for an endpoint. type EndpointPort uint16 diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index f4a55b4edc..38249a87f9 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -6,6 +6,7 @@ package dslx import ( "context" + "errors" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -61,11 +62,6 @@ type ResolvedAddresses struct { // Domain is the domain we resolved. We inherit this field // from the value inside the DomainToResolve. Domain string - - // Trace is the trace we're currently using. This struct is - // created by the various Apply functions using values inside - // the DomainToResolve to initialize the Trace. - Trace Trace } // DNSLookupGetaddrinfo returns a function that resolves a domain name to @@ -109,7 +105,6 @@ func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *ResolvedAddresses] state := &ResolvedAddresses{ Addresses: addrs, Domain: input.Domain, - Trace: trace, } return state, nil }) @@ -161,7 +156,42 @@ func DNSLookupUDP(rt Runtime, endpoint string) Func[*DomainToResolve, *ResolvedA state := &ResolvedAddresses{ Addresses: addrs, Domain: input.Domain, - Trace: trace, + } + return state, nil + }) +} + +// ErrDNSLookupParallel indicates that DNSLookupParallel failed. +var ErrDNSLookupParallel = errors.New("dslx: DNSLookupParallel failed") + +// DNSLookupParallel runs DNS lookups in parallel. On success, this function returns +// a unique list of IP addresses aggregated from all resolvers. On failure, this function +// returns [ErrDNSLookupParallel]. You can always obtain the individual errors by +// processing observations or by creating a per-DNS-resolver pipeline. +func DNSLookupParallel(fxs ...Func[*DomainToResolve, *ResolvedAddresses]) Func[*DomainToResolve, *ResolvedAddresses] { + return Operation[*DomainToResolve, *ResolvedAddresses](func(ctx context.Context, domain *DomainToResolve) (*ResolvedAddresses, error) { + // run all the DNS resolvers in parallel + results := Parallel(ctx, Parallelism(2), domain, fxs...) + + // reduce addresses + addressSet := NewAddressSet() + for _, result := range results { + if err := result.Error; err != nil { + continue + } + addressSet.Add(result.State.Addresses...) + } + uniq := addressSet.Uniq() + + // handle the case where all the DNS resolvers failed + if len(uniq) < 1 { + return nil, ErrDNSLookupParallel + } + + // handle success + state := &ResolvedAddresses{ + Addresses: uniq, + Domain: domain.Domain, } return state, nil }) diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 84d58c1494..3e7f57cd61 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -116,9 +116,6 @@ func TestGetaddrinfo(t *testing.T) { if len(res.State.Addresses) != 1 || res.State.Addresses[0] != "93.184.216.34" { t.Fatal("unexpected addresses") } - if diff := cmp.Diff([]string{"antani"}, res.State.Trace.Tags()); diff != "" { - t.Fatal(diff) - } }) }) } @@ -208,9 +205,6 @@ func TestLookupUDP(t *testing.T) { if len(res.State.Addresses) != 1 || res.State.Addresses[0] != "93.184.216.34" { t.Fatal("unexpected addresses") } - if diff := cmp.Diff([]string{"antani"}, res.State.Trace.Tags()); diff != "" { - t.Fatal(diff) - } }) }) } diff --git a/internal/dslx/qa_test.go b/internal/dslx/qa_test.go new file mode 100644 index 0000000000..f329482d2b --- /dev/null +++ b/internal/dslx/qa_test.go @@ -0,0 +1,136 @@ +package dslx_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/apex/log" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/ooni/netem" + "github.com/ooni/probe-cli/v3/internal/dslx" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netemx" + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +// qaStringLessFunc is an utility function to force cmp.Diff to sort string +// slices before performing comparison so that the order doesn't matter +func qaStringLessFunc(a, b string) bool { + return a < b +} + +func TestDNSLookupQA(t *testing.T) { + // testcase is a test case implemented by this function + type testcase struct { + // name is the test case name + name string + + // newRuntime is the function that creates a new runtime + newRuntime func(netx model.MeasuringNetwork) dslx.Runtime + + // configureDPI configures DPI + configureDPI func(dpi *netem.DPIEngine) + + // domain is the domain to resolve + domain dslx.DomainName + + // expectErr is the expected DNS error or nil + expectErr error + + // expectAddrs contains the expected DNS addresses + expectAddrs []string + } + + cases := []testcase{{ + name: "successful case with minimal runtime", + newRuntime: func(netx model.MeasuringNetwork) dslx.Runtime { + return dslx.NewMinimalRuntime(log.Log, time.Now(), dslx.MinimalRuntimeOptionMeasuringNetwork(netx)) + }, + configureDPI: func(dpi *netem.DPIEngine) { + // nothing + }, + domain: "dns.google", + expectErr: nil, + expectAddrs: []string{"8.8.8.8", "8.8.4.4"}, + }, { + name: "with injected nxdomain error and minimal runtime", + newRuntime: func(netx model.MeasuringNetwork) dslx.Runtime { + return dslx.NewMinimalRuntime(log.Log, time.Now(), dslx.MinimalRuntimeOptionMeasuringNetwork(netx)) + }, + configureDPI: func(dpi *netem.DPIEngine) { + dpi.AddRule(&netem.DPISpoofDNSResponse{ + Addresses: []string{}, // empty to cause NXDOMAIN + Logger: log.Log, + Domain: "dns.google", + }) + }, + domain: "dns.google", + expectErr: dslx.ErrDNSLookupParallel, + expectAddrs: []string{}, + }} + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // create an internet testing scenario + env := netemx.MustNewScenario(netemx.InternetScenario) + defer env.Close() + + // create a dslx.Runtime using the client stack + rt := tc.newRuntime(&netxlite.Netx{ + Underlying: &netxlite.NetemUnderlyingNetworkAdapter{UNet: env.ClientStack}, + }) + defer rt.Close() + + // configure the DPI engine + tc.configureDPI(env.DPIEngine()) + + // create DNS lookup function + function := dslx.DNSLookupParallel( + dslx.DNSLookupGetaddrinfo(rt), + dslx.DNSLookupUDP(rt, net.JoinHostPort(netemx.AddressDNSQuad9Net, "53")), + ) + + // create context + ctx := context.Background() + + // perform DNS lookup + results := function.Apply(ctx, dslx.NewMaybeWithValue(dslx.NewDomainToResolve(tc.domain))) + + // unpack the results + resolvedAddrs, err := results.State, results.Error + + // make sure the error matches expectations + switch { + case err == nil && tc.expectErr == nil: + // nothing + + case err != nil && tc.expectErr == nil: + t.Fatal("expected", tc.expectErr, "got", err) + + case err == nil && tc.expectErr != nil: + t.Fatal("expected", tc.expectErr, "got", err) + + case err != nil && tc.expectErr != nil: + if err.Error() != tc.expectErr.Error() { + t.Fatal("expected", tc.expectErr, "got", err) + } + return // no reason to continue + } + + // make sure that the domain has been correctly copied + if resolvedAddrs.Domain != string(tc.domain) { + t.Fatal("expected", tc.domain, "got", resolvedAddrs.Domain) + } + + // make sure we resolved the expected IP addresses + if diff := cmp.Diff(tc.expectAddrs, resolvedAddrs.Addresses, cmpopts.SortSlices(qaStringLessFunc)); diff != "" { + t.Fatal(diff) + } + + // TODO(bassosimone): make sure the observations are OK + }) + } +}