From 4b9914c912451298c29e7d0bf9d429c8564d04c6 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 25 Oct 2023 14:15:19 +0200 Subject: [PATCH] refactor(dslx): type Operation func(context, A) (B, error) (#1386) Closes https://github.com/ooni/probe/issues/2615 --- internal/dslx/dns.go | 32 +++++++++++++++------------- internal/dslx/dns_test.go | 14 ++++--------- internal/dslx/fxcore.go | 43 +++++++++++++++----------------------- internal/dslx/http_test.go | 6 ------ internal/dslx/httpcore.go | 20 ++++++++++-------- internal/dslx/httpquic.go | 9 +++----- internal/dslx/httptcp.go | 9 ++++---- internal/dslx/httptls.go | 9 ++++---- internal/dslx/quic.go | 16 +++++++------- internal/dslx/quic_test.go | 4 ++-- internal/dslx/tcp.go | 16 +++++++------- internal/dslx/tcp_test.go | 4 ++-- internal/dslx/tls.go | 16 +++++++------- internal/dslx/tls_test.go | 4 ++-- 14 files changed, 94 insertions(+), 108 deletions(-) diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 61e6ff3771..f4a55b4edc 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -71,7 +71,7 @@ type ResolvedAddresses struct { // DNSLookupGetaddrinfo returns a function that resolves a domain name to // IP addresses using libc's getaddrinfo function. func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *ResolvedAddresses] { - return Operation[*DomainToResolve, *ResolvedAddresses](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { + return Operation[*DomainToResolve, *ResolvedAddresses](func(ctx context.Context, input *DomainToResolve) (*ResolvedAddresses, error) { // create trace trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...) @@ -100,23 +100,25 @@ func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *ResolvedAddresses] // save the observations rt.SaveObservations(maybeTraceToObservations(trace)...) + // handle error case + if err != nil { + return nil, err + } + + // handle success state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty + Addresses: addrs, Domain: input.Domain, Trace: trace, } - - return &Maybe[*ResolvedAddresses]{ - Error: err, - State: state, - } + return state, nil }) } // DNSLookupUDP returns a function that resolves a domain name to // IP addresses using the given DNS-over-UDP resolver. func DNSLookupUDP(rt Runtime, endpoint string) Func[*DomainToResolve, *ResolvedAddresses] { - return Operation[*DomainToResolve, *ResolvedAddresses](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { + return Operation[*DomainToResolve, *ResolvedAddresses](func(ctx context.Context, input *DomainToResolve) (*ResolvedAddresses, error) { // create trace trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...) @@ -150,15 +152,17 @@ func DNSLookupUDP(rt Runtime, endpoint string) Func[*DomainToResolve, *ResolvedA // save the observations rt.SaveObservations(maybeTraceToObservations(trace)...) + // handle error case + if err != nil { + return nil, err + } + + // handle success state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty + Addresses: addrs, Domain: input.Domain, Trace: trace, } - - return &Maybe[*ResolvedAddresses]{ - Error: err, - State: state, - } + return state, nil }) } diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 49c19fcf68..84d58c1494 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -90,11 +90,8 @@ func TestGetaddrinfo(t *testing.T) { if res.Error != mockedErr { t.Fatalf("unexpected error type: %s", res.Error) } - if res.State == nil { - t.Fatal("unexpected nil state") - } - if res.State.Addresses != nil { - t.Fatal("expected empty addresses here") + if res.State != nil { + t.Fatal("expected nil state") } }) @@ -178,11 +175,8 @@ func TestLookupUDP(t *testing.T) { if res.Error != mockedErr { t.Fatalf("unexpected error type: %s", res.Error) } - if res.State == nil { - t.Fatal("unexpected nil state") - } - if res.State.Addresses != nil { - t.Fatal("expected empty addresses here") + if res.State != nil { + t.Fatal("expected nil state") } }) diff --git a/internal/dslx/fxcore.go b/internal/dslx/fxcore.go index 4ed1450dba..8074a4b18d 100644 --- a/internal/dslx/fxcore.go +++ b/internal/dslx/fxcore.go @@ -6,8 +6,6 @@ package dslx import ( "context" - - "github.com/ooni/probe-cli/v3/internal/runtimex" ) // Func is a function f: (context.Context, A) -> B. @@ -16,17 +14,18 @@ type Func[A, B any] interface { } // Operation adapts a golang function to behave like a Func. -type Operation[A, B any] func(ctx context.Context, a A) *Maybe[B] +type Operation[A, B any] func(ctx context.Context, a A) (B, error) // Apply implements Func. func (op Operation[A, B]) Apply(ctx context.Context, a *Maybe[A]) *Maybe[B] { - if a.Error != nil { - return &Maybe[B]{ - Error: a.Error, - State: *new(B), // zero value - } + if err := a.Error; err != nil { + return NewMaybeWithError[B](err) + } + out, err := op(ctx, a.State) + if err != nil { + return NewMaybeWithError[B](err) } - return op(ctx, a.State) + return NewMaybeWithValue(out) } // Maybe is the result of an operation implemented by this package @@ -48,6 +47,14 @@ func NewMaybeWithValue[State any](value State) *Maybe[State] { } } +// NewMaybeWithError constructs a Maybe containing the given error. +func NewMaybeWithError[State any](err error) *Maybe[State] { + return &Maybe[State]{ + Error: err, + State: *new(State), // zero value + } +} + // Compose2 composes two operations such as [TCPConnect] and [TLSHandshake]. func Compose2[A, B, C any](f Func[A, B], g Func[B, C]) Func[A, C] { return &compose2Func[A, B, C]{ @@ -64,21 +71,5 @@ type compose2Func[A, B, C any] struct { // Apply implements Func func (h *compose2Func[A, B, C]) Apply(ctx context.Context, a *Maybe[A]) *Maybe[C] { - mb := h.f.Apply(ctx, a) - runtimex.Assert(mb != nil, "h.f.Apply returned a nil pointer") - - if mb.Error != nil { - return &Maybe[C]{ - Error: mb.Error, - State: *new(C), // zero value - } - } - - mc := h.g.Apply(ctx, mb) - runtimex.Assert(mc != nil, "h.g.Apply returned a nil pointer") - - return &Maybe[C]{ - Error: mc.Error, - State: mc.State, - } + return h.g.Apply(ctx, h.f.Apply(ctx, a)) } diff --git a/internal/dslx/http_test.go b/internal/dslx/http_test.go index 15f7feeb0c..4b83575ecb 100644 --- a/internal/dslx/http_test.go +++ b/internal/dslx/http_test.go @@ -245,9 +245,6 @@ func TestHTTPRequest(t *testing.T) { if res.Error != io.EOF { t.Fatal("not the error we expected") } - if res.State.HTTPResponse != nil { - t.Fatal("expected nil request here") - } }) t.Run("with invalid domain", func(t *testing.T) { @@ -265,9 +262,6 @@ func TestHTTPRequest(t *testing.T) { if res.Error == nil || !strings.HasPrefix(res.Error.Error(), `parse "https://%09/": invalid URL escape "%09"`) { t.Fatal("not the error we expected", res.Error) } - if res.State.HTTPResponse != nil { - t.Fatal("expected nil request here") - } }) t.Run("with port-less address", func(t *testing.T) { diff --git a/internal/dslx/httpcore.go b/internal/dslx/httpcore.go index 683070a4a0..56bd6241e4 100644 --- a/internal/dslx/httpcore.go +++ b/internal/dslx/httpcore.go @@ -102,7 +102,7 @@ func HTTPRequestOptionUserAgent(value string) HTTPRequestOption { // HTTPRequest issues an HTTP request using a transport and returns a response. func HTTPRequest(rt Runtime, options ...HTTPRequestOption) Func[*HTTPConnection, *HTTPResponse] { - return Operation[*HTTPConnection, *HTTPResponse](func(ctx context.Context, input *HTTPConnection) *Maybe[*HTTPResponse] { + return Operation[*HTTPConnection, *HTTPResponse](func(ctx context.Context, input *HTTPConnection) (*HTTPResponse, error) { // setup const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) @@ -140,20 +140,22 @@ func HTTPRequest(rt Runtime, options ...HTTPRequestOption) Func[*HTTPConnection, observations = append(observations, maybeTraceToObservations(input.Trace)...) rt.SaveObservations(observations...) + // handle error case + if err != nil { + return nil, err + } + + // handle success state := &HTTPResponse{ Address: input.Address, Domain: input.Domain, - HTTPRequest: req, // possibly nil - HTTPResponse: resp, // possibly nil - HTTPResponseBodySnapshot: body, // possibly nil + HTTPRequest: req, + HTTPResponse: resp, + HTTPResponseBodySnapshot: body, Network: input.Network, Trace: input.Trace, } - - return &Maybe[*HTTPResponse]{ - Error: err, - State: state, - } + return state, nil }) } diff --git a/internal/dslx/httpquic.go b/internal/dslx/httpquic.go index 18ffaac4a8..aba372b652 100644 --- a/internal/dslx/httpquic.go +++ b/internal/dslx/httpquic.go @@ -17,8 +17,7 @@ func HTTPRequestOverQUIC(rt Runtime, options ...HTTPRequestOption) Func[*QUICCon // HTTPConnectionQUIC converts a QUIC connection into an HTTP connection. func HTTPConnectionQUIC(rt Runtime) Func[*QUICConnection, *HTTPConnection] { - return Operation[*QUICConnection, *HTTPConnection](func(ctx context.Context, input *QUICConnection) *Maybe[*HTTPConnection] { - // create transport + return Operation[*QUICConnection, *HTTPConnection](func(ctx context.Context, input *QUICConnection) (*HTTPConnection, error) { httpTransport := netxlite.NewHTTP3Transport( rt.Logger(), netxlite.NewSingleUseQUICDialer(input.QUICConn), @@ -34,9 +33,7 @@ func HTTPConnectionQUIC(rt Runtime) Func[*QUICConnection, *HTTPConnection] { Trace: input.Trace, Transport: httpTransport, } - return &Maybe[*HTTPConnection]{ - Error: nil, - State: state, - } + + return state, nil }) } diff --git a/internal/dslx/httptcp.go b/internal/dslx/httptcp.go index f281a16235..a42f4c4248 100644 --- a/internal/dslx/httptcp.go +++ b/internal/dslx/httptcp.go @@ -17,7 +17,7 @@ func HTTPRequestOverTCP(rt Runtime, options ...HTTPRequestOption) Func[*TCPConne // HTTPConnectionTCP converts a TCP connection into an HTTP connection. func HTTPConnectionTCP(rt Runtime) Func[*TCPConnection, *HTTPConnection] { - return Operation[*TCPConnection, *HTTPConnection](func(ctx context.Context, input *TCPConnection) *Maybe[*HTTPConnection] { + return Operation[*TCPConnection, *HTTPConnection](func(ctx context.Context, input *TCPConnection) (*HTTPConnection, error) { // TODO(https://github.com/ooni/probe/issues/2534): here we're using the QUIRKY netxlite.NewHTTPTransport // function, but we can probably avoid using it, given that this code is // not using tracing and does not care about those quirks. @@ -26,6 +26,7 @@ func HTTPConnectionTCP(rt Runtime) Func[*TCPConnection, *HTTPConnection] { netxlite.NewSingleUseDialer(input.Conn), netxlite.NewNullTLSDialer(), ) + state := &HTTPConnection{ Address: input.Address, Domain: input.Domain, @@ -35,9 +36,7 @@ func HTTPConnectionTCP(rt Runtime) Func[*TCPConnection, *HTTPConnection] { Trace: input.Trace, Transport: httpTransport, } - return &Maybe[*HTTPConnection]{ - Error: nil, - State: state, - } + + return state, nil }) } diff --git a/internal/dslx/httptls.go b/internal/dslx/httptls.go index 3d0ccc63bb..e5a2128d63 100644 --- a/internal/dslx/httptls.go +++ b/internal/dslx/httptls.go @@ -17,7 +17,7 @@ func HTTPRequestOverTLS(rt Runtime, options ...HTTPRequestOption) Func[*TLSConne // HTTPConnectionTLS converts a TLS connection into an HTTP connection. func HTTPConnectionTLS(rt Runtime) Func[*TLSConnection, *HTTPConnection] { - return Operation[*TLSConnection, *HTTPConnection](func(ctx context.Context, input *TLSConnection) *Maybe[*HTTPConnection] { + return Operation[*TLSConnection, *HTTPConnection](func(ctx context.Context, input *TLSConnection) (*HTTPConnection, error) { // TODO(https://github.com/ooni/probe/issues/2534): here we're using the QUIRKY netxlite.NewHTTPTransport // function, but we can probably avoid using it, given that this code is // not using tracing and does not care about those quirks. @@ -26,6 +26,7 @@ func HTTPConnectionTLS(rt Runtime) Func[*TLSConnection, *HTTPConnection] { netxlite.NewNullDialer(), netxlite.NewSingleUseTLSDialer(input.Conn), ) + state := &HTTPConnection{ Address: input.Address, Domain: input.Domain, @@ -35,9 +36,7 @@ func HTTPConnectionTLS(rt Runtime) Func[*TLSConnection, *HTTPConnection] { Trace: input.Trace, Transport: httpTransport, } - return &Maybe[*HTTPConnection]{ - Error: nil, - State: state, - } + + return state, nil }) } diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index 6ab82c0896..3ad65241c6 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -17,7 +17,7 @@ import ( // QUICHandshake returns a function performing QUIC handshakes. func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[*Endpoint, *QUICConnection] { - return Operation[*Endpoint, *QUICConnection](func(ctx context.Context, input *Endpoint) *Maybe[*QUICConnection] { + return Operation[*Endpoint, *QUICConnection](func(ctx context.Context, input *Endpoint) (*QUICConnection, error) { // create trace trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...) @@ -60,20 +60,22 @@ func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[*Endpoint, *Q // save the observations rt.SaveObservations(maybeTraceToObservations(trace)...) + // handle error case + if err != nil { + return nil, err + } + + // handle success state := &QUICConnection{ Address: input.Address, - QUICConn: quicConn, // possibly nil + QUICConn: quicConn, Domain: input.Domain, Network: input.Network, TLSConfig: config, TLSState: tlsState, Trace: trace, } - - return &Maybe[*QUICConnection]{ - Error: err, - State: state, - } + return state, nil }) } diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 17328fa4f4..d798adc179 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -96,8 +96,8 @@ func TestQUICHandshake(t *testing.T) { if res.Error != tt.expectErr { t.Fatalf("unexpected error: %s", res.Error) } - if res.State == nil || res.State.QUICConn != tt.expectConn { - t.Fatal("unexpected conn") + if res.Error == nil && res.State.QUICConn != tt.expectConn { + t.Fatalf("unexpected conn %v", res.State) } rt.Close() if wasClosed != tt.closed { diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index cc15bd5c69..9a48c148cc 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -14,7 +14,7 @@ import ( // TCPConnect returns a function that establishes TCP connections. func TCPConnect(rt Runtime) Func[*Endpoint, *TCPConnection] { - return Operation[*Endpoint, *TCPConnection](func(ctx context.Context, input *Endpoint) *Maybe[*TCPConnection] { + return Operation[*Endpoint, *TCPConnection](func(ctx context.Context, input *Endpoint) (*TCPConnection, error) { // create trace trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...) @@ -46,18 +46,20 @@ func TCPConnect(rt Runtime) Func[*Endpoint, *TCPConnection] { // save the observations rt.SaveObservations(maybeTraceToObservations(trace)...) + // handle error case + if err != nil { + return nil, err + } + + // handle success state := &TCPConnection{ Address: input.Address, - Conn: conn, // possibly nil + Conn: conn, Domain: input.Domain, Network: input.Network, Trace: trace, } - - return &Maybe[*TCPConnection]{ - Error: err, - State: state, - } + return state, nil }) } diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 900aa94a4e..363b8664bb 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -73,8 +73,8 @@ func TestTCPConnect(t *testing.T) { if res.Error != tt.expectErr { t.Fatalf("unexpected error: %s", res.Error) } - if res.State == nil || res.State.Conn != tt.expectConn { - t.Fatal("unexpected conn") + if res.Error == nil && res.State.Conn != tt.expectConn { + t.Fatalf("unexpected conn %v", res.State) } rt.Close() if wasClosed != tt.closed { diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 1add7d578c..ac4675630d 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -49,7 +49,7 @@ func TLSHandshakeOptionServerName(value string) TLSHandshakeOption { // TLSHandshake returns a function performing TSL handshakes. func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[*TCPConnection, *TLSConnection] { - return Operation[*TCPConnection, *TLSConnection](func(ctx context.Context, input *TCPConnection) *Maybe[*TLSConnection] { + return Operation[*TCPConnection, *TLSConnection](func(ctx context.Context, input *TCPConnection) (*TLSConnection, error) { // keep using the same trace trace := input.Trace @@ -86,19 +86,21 @@ func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[*TCPConnection // save the observations rt.SaveObservations(maybeTraceToObservations(trace)...) + // handle error case + if err != nil { + return nil, err + } + + // handle success state := &TLSConnection{ Address: input.Address, - Conn: conn, // possibly nil + Conn: conn, Domain: input.Domain, Network: input.Network, TLSState: netxlite.MaybeTLSConnectionState(conn), Trace: trace, } - - return &Maybe[*TLSConnection]{ - Error: err, - State: state, - } + return state, nil }) } diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 6d8f8e266b..563ad15460 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -173,8 +173,8 @@ func TestTLSHandshake(t *testing.T) { if res.Error != tt.expectErr { t.Fatalf("unexpected error: %s", res.Error) } - if res.State.Conn != tt.expectConn { - t.Fatalf("unexpected conn %v", res.State.Conn) + if res.State != nil && res.State.Conn != tt.expectConn { + t.Fatalf("unexpected conn %v", res.State) } rt.Close() if wasClosed != tt.closed {