diff --git a/client.go b/client.go index efee53c..1821f49 100644 --- a/client.go +++ b/client.go @@ -41,7 +41,7 @@ import ( "sync" "time" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" ) var ( @@ -364,6 +364,43 @@ func (h hookLogger) Printf(s string, args ...interface{}) { h.Info(fmt.Sprintf(s, args...)) } +// ContextLogger is an interface that provides methods for logging +// with context. The methods accept a context.Context, a message +// string and a variadic number of key-value pairs. +type ContextLogger interface { + ErrorContext(ctx context.Context, msg string, keysAndValues ...interface{}) + InfoContext(ctx context.Context, msg string, keysAndValues ...interface{}) + DebugContext(ctx context.Context, msg string, keysAndValues ...interface{}) + WarnContext(ctx context.Context, msg string, keysAndValues ...interface{}) +} + +// contextLogger adapts an ContextLogger to Logger for use by the existing hook functions +// without changing the API. +type contextLogger struct { + context.Context + ContextLogger +} + +func (s contextLogger) ErrorContext(ctx context.Context, msg string, keysAndValues ...interface{}) { + s.ErrorContext(ctx, msg, keysAndValues...) +} + +func (s contextLogger) InfoContext(ctx context.Context, msg string, keysAndValues ...interface{}) { + s.InfoContext(ctx, msg, keysAndValues...) +} + +func (s contextLogger) DebugContext(ctx context.Context, msg string, keysAndValues ...interface{}) { + s.DebugContext(ctx, msg, keysAndValues...) +} + +func (s contextLogger) WarnContext(ctx context.Context, msg string, keysAndValues ...interface{}) { + s.WarnContext(ctx, msg, keysAndValues...) +} + +func (s contextLogger) Printf(msg string, keysAndValues ...interface{}) { + s.InfoContext(s.Context, msg, keysAndValues...) +} + // RequestLogHook allows a function to run before each retry. The HTTP // request which will be made, and the retry number (0 for the initial // request) are available to users. The internal logger is exposed to @@ -405,7 +442,7 @@ type PrepareRetry func(req *http.Request) error // like automatic retries to tolerate minor outages. type Client struct { HTTPClient *http.Client // Internal HTTP client. - Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger + Logger interface{} // Customer logger instance. Supports Logger, LeveledLogger, or ContextLogger. RetryWaitMin time.Duration // Minimum time to wait RetryWaitMax time.Duration // Maximum time to wait @@ -456,11 +493,11 @@ func (c *Client) logger() interface{} { } switch c.Logger.(type) { - case Logger, LeveledLogger: + case Logger, LeveledLogger, ContextLogger: // ok default: // This should happen in dev when they are setting Logger and work on code, not in prod. - panic(fmt.Sprintf("invalid logger type passed, must be Logger or LeveledLogger, was %T", c.Logger)) + panic(fmt.Sprintf("invalid logger type passed, must in Logger, LeveledLogger, ContextLogger, was %T", c.Logger)) } }) @@ -657,6 +694,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) { if logger != nil { switch v := logger.(type) { + case ContextLogger: + v.DebugContext(req.Context(), "performing request", "method", req.Method, "url", redactURL(req.URL)) case LeveledLogger: v.Debug("performing request", "method", req.Method, "url", redactURL(req.URL)) case Logger: @@ -689,6 +728,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) { if c.RequestLogHook != nil { switch v := logger.(type) { + case ContextLogger: + c.RequestLogHook(contextLogger{req.Context(), v}, req.Request, i) case LeveledLogger: c.RequestLogHook(hookLogger{v}, req.Request, i) case Logger: @@ -714,6 +755,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } if err != nil { switch v := logger.(type) { + case ContextLogger: + v.ErrorContext(req.Context(), "request failed", "error", err, "method", req.Method, "url", redactURL(req.URL)) case LeveledLogger: v.Error("request failed", "error", err, "method", req.Method, "url", redactURL(req.URL)) case Logger: @@ -725,6 +768,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) { if c.ResponseLogHook != nil { // Call the response logger function if provided. switch v := logger.(type) { + case ContextLogger: + c.ResponseLogHook(contextLogger{req.Context(), v}, resp) case LeveledLogger: c.ResponseLogHook(hookLogger{v}, resp) case Logger: @@ -748,7 +793,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { // We're going to retry, consume any response to reuse the connection. if doErr == nil { - c.drainBody(resp.Body) + c.drainBody(req.Context(), resp.Body) } wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp) @@ -758,6 +803,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) { desc = fmt.Sprintf("%s (status: %d)", desc, resp.StatusCode) } switch v := logger.(type) { + case ContextLogger: + v.DebugContext(req.Context(), "retrying request", "request", desc, "timeout", wait, "remaining", remain) case LeveledLogger: v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain) case Logger: @@ -811,7 +858,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { // By default, we close the response body and return an error without // returning the response if resp != nil { - c.drainBody(resp.Body) + c.drainBody(req.Context(), resp.Body) } // this means CheckRetry thought the request was a failure, but didn't @@ -826,12 +873,14 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } // Try to read the response body so we can reuse this connection. -func (c *Client) drainBody(body io.ReadCloser) { +func (c *Client) drainBody(ctx context.Context, body io.ReadCloser) { defer body.Close() _, err := io.Copy(io.Discard, io.LimitReader(body, respReadLimit)) if err != nil { if c.logger() != nil { switch v := c.logger().(type) { + case ContextLogger: + v.ErrorContext(ctx, "error reading response body", "error", err) case LeveledLogger: v.Error("error reading response body", "error", err) case Logger: diff --git a/client_test.go b/client_test.go index cc05e91..a8eab65 100644 --- a/client_test.go +++ b/client_test.go @@ -570,6 +570,9 @@ func TestClient_RequestLogHook(t *testing.T) { t.Run("RequestLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) { testClientRequestLogHook(t, LeveledLogger(nil)) }) + t.Run("RequestLogHook successfully called with nil typed ContextLogger", func(t *testing.T) { + testClientRequestLogHook(t, ContextLogger(nil)) + }) } func testClientRequestLogHook(t *testing.T, logger interface{}) { @@ -639,6 +642,10 @@ func TestClient_ResponseLogHook(t *testing.T) { buf := new(bytes.Buffer) testClientResponseLogHook(t, LeveledLogger(nil), buf) }) + t.Run("ResponseLogHook successfully called with nil typed ContextLogger", func(t *testing.T) { + buf := new(bytes.Buffer) + testClientResponseLogHook(t, ContextLogger(nil), buf) + }) } func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) {