diff --git a/pgconn/errors.go b/pgconn/errors.go index c315739a9..503e76da7 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -59,14 +59,15 @@ func (pe *PgError) SQLState() string { // ConnectError is the error returned when a connection attempt fails. type ConnectError struct { - Config *Config // The configuration that was used in the connection attempt. - msg string - err error + Config *Config // The configuration that was used in the connection attempt. + fallbackConfig *FallbackConfig + msg string + err error } func (e *ConnectError) Error() string { sb := &strings.Builder{} - fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg) + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.fallbackConfig.Host, e.Config.User, e.Config.Database, e.msg) if e.err != nil { fmt.Fprintf(sb, " (%s)", e.err.Error()) } @@ -230,3 +231,18 @@ func (e *NotPreferredError) SafeToRetry() bool { func (e *NotPreferredError) Unwrap() error { return e.err } + +type lookupError struct { + err error + fallbackConfig *FallbackConfig +} + +func (e *lookupError) Error() string { + return e.err.Error() +} + +func (e *lookupError) Unwrap() error { + return e.err +} + +var errIPAddrNotFound = errors.New("ip address not found") diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 0bf03f335..37836fd14 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -152,15 +152,16 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er ctx := octx fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) if err != nil { - return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err} + lookupErr := err.(*lookupError) + return nil, &ConnectError{Config: config, fallbackConfig: lookupErr.fallbackConfig, msg: "hostname resolving error", err: lookupErr.err} } if len(fallbackConfigs) == 0 { - return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfigs[0], msg: "hostname resolving error", err: errIPAddrNotFound} } - foundBestServer := false - var fallbackConfig *FallbackConfig + var usedFallbackConfig *FallbackConfig + var notPreferredFallbackConfig *FallbackConfig for i, fc := range fallbackConfigs { // ConnectTimeout restricts the whole connection process. if config.ConnectTimeout != 0 { @@ -175,10 +176,10 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er } pgConn, err = connect(ctx, config, fc, false) if err == nil { - foundBestServer = true + usedFallbackConfig = fc break } else if pgerr, ok := err.(*PgError); ok { - err = &ConnectError{Config: config, msg: "server error", err: pgerr} + err = &ConnectError{Config: config, fallbackConfig: fc, msg: "server error", err: pgerr} const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist @@ -191,16 +192,17 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er } } else if cerr, ok := err.(*ConnectError); ok { if _, ok := cerr.err.(*NotPreferredError); ok { - fallbackConfig = fc + notPreferredFallbackConfig = fc } } } - if !foundBestServer && fallbackConfig != nil { - pgConn, err = connect(ctx, config, fallbackConfig, true) + if usedFallbackConfig == nil && notPreferredFallbackConfig != nil { + pgConn, err = connect(ctx, config, notPreferredFallbackConfig, true) if pgerr, ok := err.(*PgError); ok { - err = &ConnectError{Config: config, msg: "server error", err: pgerr} + err = &ConnectError{Config: config, fallbackConfig: notPreferredFallbackConfig, msg: "server error", err: pgerr} } + usedFallbackConfig = notPreferredFallbackConfig } if err != nil { @@ -211,7 +213,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err} + return nil, &ConnectError{Config: config, fallbackConfig: usedFallbackConfig, msg: "AfterConnect error", err: err} } } @@ -237,7 +239,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba ips, err := lookupFn(ctx, fb.Host) if err != nil { - lookupErrors = append(lookupErrors, err) + lookupErrors = append(lookupErrors, &lookupError{err: err, fallbackConfig: fb}) continue } @@ -246,7 +248,10 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba if err == nil { port, err := strconv.ParseUint(splitPort, 10, 16) if err != nil { - return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + return nil, &lookupError{ + err: fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err), + fallbackConfig: fb, + } } configs = append(configs, &FallbackConfig{ Host: splitIP, @@ -283,7 +288,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) netConn, err := config.DialFunc(ctx, network, address) if err != nil { - return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } pgConn.conn = netConn @@ -295,7 +300,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() - return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "tls error", err: normalizeTimeoutError(ctx, err)} } pgConn.conn = nbTLSConn @@ -336,7 +341,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.frontend.Send(&startupMsg) if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} } for { @@ -346,7 +351,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -359,26 +364,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err = pgConn.txPasswordMessage(pgConn.config.Password) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationSASL: err = pgConn.scramAuth(msg.AuthMechanisms) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed SASL auth", err: err} } case *pgproto3.AuthenticationGSS: err = pgConn.gssAuth() if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed GSS auth", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -396,7 +401,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return pgConn, nil } pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -407,7 +412,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err} + return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "received unexpected message", err: err} } } } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index b77d21c17..73f55e487 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -379,6 +379,22 @@ func TestConnectWithConnectionRefused(t *testing.T) { } } +func TestConnectErrorHasCorrectHost(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // Presumably nothing is listening on 127.0.0.1:1 and 127.0.0.1:2 + conn, err := pgconn.Connect(ctx, "postgresql://user:password@localhost:1,127.0.0.1:2/database") + if err == nil { + conn.Close(ctx) + t.Fatal("Expected error establishing connection to bad port") + } + require.ErrorContains(t, err, "host=127.0.0.1") + require.ErrorContains(t, err, "dial tcp 127.0.0.1:2") +} + func TestConnectCustomDialer(t *testing.T) { t.Parallel()