diff --git a/client_ext_test.go b/client_ext_test.go index 9dcec272..502049a7 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -31,7 +31,7 @@ import ( "testing" "time" - connect "connectrpc.com/connect" + "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" @@ -227,6 +227,58 @@ func TestGetNoContentHeaders(t *testing.T) { assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) } +func TestConnectionDropped(t *testing.T) { + t.Parallel() + ctx := context.Background() + for _, protocol := range []string{connect.ProtocolConnect, connect.ProtocolGRPC, connect.ProtocolGRPCWeb} { + var opts []connect.ClientOption + switch protocol { + case connect.ProtocolGRPC: + opts = []connect.ClientOption{connect.WithGRPC()} + case connect.ProtocolGRPCWeb: + opts = []connect.ClientOption{connect.WithGRPCWeb()} + } + t.Run(protocol, func(t *testing.T) { + t.Parallel() + httpClient := httpClientFunc(func(_ *http.Request) (*http.Response, error) { + return nil, io.EOF + }) + client := pingv1connect.NewPingServiceClient( + httpClient, + "http://1.2.3.4", + opts..., + ) + t.Run("unary", func(t *testing.T) { + t.Parallel() + req := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, req) + assert.NotNil(t, err) + if !assert.Equal(t, connect.CodeOf(err), connect.CodeUnavailable) { + t.Logf("err = %v\n%#v", err, err) + } + }) + t.Run("stream", func(t *testing.T) { + t.Parallel() + req := connect.NewRequest[pingv1.CountUpRequest](nil) + svrStream, err := client.CountUp(ctx, req) + if err == nil { + t.Cleanup(func() { + assert.Nil(t, svrStream.Close()) + }) + if !assert.False(t, svrStream.Receive()) { + return + } + err = svrStream.Err() + } + assert.NotNil(t, err) + if !assert.Equal(t, connect.CodeOf(err), connect.CodeUnavailable) { + t.Logf("err = %v\n%#v", err, err) + } + }) + }) + } +} + func TestSpecSchema(t *testing.T) { t.Parallel() mux := http.NewServeMux() @@ -762,3 +814,9 @@ func addUnrecognizedBytes[M proto.Message](msg M, data []byte) M { msg.ProtoReflect().SetUnknown(data) return msg } + +type httpClientFunc func(*http.Request) (*http.Response, error) + +func (fn httpClientFunc) Do(req *http.Request) (*http.Response, error) { + return fn(req) +} diff --git a/duplex_http_call.go b/duplex_http_call.go index 78bf13c2..046dc95a 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -306,6 +306,11 @@ func (d *duplexHTTPCall) makeRequest() { // pipe. Write's check for io.ErrClosedPipe and will convert this to io.EOF. response, err := d.httpClient.Do(d.request) //nolint:bodyclose if err != nil { + if errors.Is(err, io.EOF) { + // We use io.EOF as a sentinel in many places and don't want this + // transport error to be confused for those other situations. + err = io.ErrUnexpectedEOF + } err = wrapIfContextError(err) err = wrapIfLikelyH2CNotConfiguredError(d.request, err) err = wrapIfLikelyWithGRPCNotUsedError(err)