Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct host in ConnectError #1937

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions pgconn/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,15 @@ func (pe *PgError) SQLState() string {

// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string
err error
Config *Config // The configuration that was used in the connection attempt.
fallbackConfig *FallbackConfig
msg string
err error
}

func (e *ConnectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.fallbackConfig.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
Expand Down Expand Up @@ -230,3 +231,18 @@ func (e *NotPreferredError) SafeToRetry() bool {
func (e *NotPreferredError) Unwrap() error {
return e.err
}

type lookupError struct {
err error
fallbackConfig *FallbackConfig
}

func (e *lookupError) Error() string {
return e.err.Error()
}

func (e *lookupError) Unwrap() error {
return e.err
}

var errIPAddrNotFound = errors.New("ip address not found")
51 changes: 28 additions & 23 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,16 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
lookupErr := err.(*lookupError)
return nil, &ConnectError{Config: config, fallbackConfig: lookupErr.fallbackConfig, msg: "hostname resolving error", err: lookupErr.err}
}

if len(fallbackConfigs) == 0 {
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfigs[0], msg: "hostname resolving error", err: errIPAddrNotFound}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems referencing fallbackConfigs[0] when len(fallbackConfigs) == 0 would always panic.

}

foundBestServer := false
var fallbackConfig *FallbackConfig
var usedFallbackConfig *FallbackConfig
var notPreferredFallbackConfig *FallbackConfig
for i, fc := range fallbackConfigs {
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
Expand All @@ -175,10 +176,10 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
}
pgConn, err = connect(ctx, config, fc, false)
if err == nil {
foundBestServer = true
usedFallbackConfig = fc
break
} else if pgerr, ok := err.(*PgError); ok {
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, fallbackConfig: fc, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
Expand All @@ -191,16 +192,17 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
}
} else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
notPreferredFallbackConfig = fc
}
}
}

if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if usedFallbackConfig == nil && notPreferredFallbackConfig != nil {
pgConn, err = connect(ctx, config, notPreferredFallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, fallbackConfig: notPreferredFallbackConfig, msg: "server error", err: pgerr}
}
usedFallbackConfig = notPreferredFallbackConfig
}

if err != nil {
Expand All @@ -211,7 +213,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
return nil, &ConnectError{Config: config, fallbackConfig: usedFallbackConfig, msg: "AfterConnect error", err: err}
}
}

Expand All @@ -237,7 +239,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba

ips, err := lookupFn(ctx, fb.Host)
if err != nil {
lookupErrors = append(lookupErrors, err)
lookupErrors = append(lookupErrors, &lookupError{err: err, fallbackConfig: fb})
continue
}

Expand All @@ -246,7 +248,10 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
if err == nil {
port, err := strconv.ParseUint(splitPort, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)
return nil, &lookupError{
err: fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err),
fallbackConfig: fb,
}
}
configs = append(configs, &FallbackConfig{
Host: splitIP,
Expand Down Expand Up @@ -283,7 +288,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address)
if err != nil {
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}

pgConn.conn = netConn
Expand All @@ -295,7 +300,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
}

pgConn.conn = nbTLSConn
Expand Down Expand Up @@ -336,7 +341,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
}

for {
Expand All @@ -346,7 +351,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
}

switch msg := msg.(type) {
Expand All @@ -359,26 +364,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed SASL auth", err: err}
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "failed GSS auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
Expand All @@ -396,7 +401,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil
}
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "ValidateConnect failed", err: err}
}
}
return pgConn, nil
Expand All @@ -407,7 +412,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg)
default:
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
return nil, &ConnectError{Config: config, fallbackConfig: fallbackConfig, msg: "received unexpected message", err: err}
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,22 @@ func TestConnectWithConnectionRefused(t *testing.T) {
}
}

func TestConnectErrorHasCorrectHost(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

// Presumably nothing is listening on 127.0.0.1:1 and 127.0.0.1:2
conn, err := pgconn.Connect(ctx, "postgresql://user:password@localhost:1,127.0.0.1:2/database")
if err == nil {
conn.Close(ctx)
t.Fatal("Expected error establishing connection to bad port")
}
require.ErrorContains(t, err, "host=127.0.0.1")
require.ErrorContains(t, err, "dial tcp 127.0.0.1:2")
}

func TestConnectCustomDialer(t *testing.T) {
t.Parallel()

Expand Down
Loading