Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for logging with context #231

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import (
"sync"
"time"

cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-cleanhttp"
)

var (
Expand Down Expand Up @@ -364,6 +364,43 @@ func (h hookLogger) Printf(s string, args ...interface{}) {
h.Info(fmt.Sprintf(s, args...))
}

// ContextLogger is an interface that provides methods for logging
// with context. The methods accept a context.Context, a message
// string and a variadic number of key-value pairs.
type ContextLogger interface {
ErrorContext(ctx context.Context, msg string, keysAndValues ...interface{})
InfoContext(ctx context.Context, msg string, keysAndValues ...interface{})
DebugContext(ctx context.Context, msg string, keysAndValues ...interface{})
WarnContext(ctx context.Context, msg string, keysAndValues ...interface{})
}

// contextLogger adapts an ContextLogger to Logger for use by the existing hook functions
// without changing the API.
type contextLogger struct {
context.Context
ContextLogger
}

func (s contextLogger) ErrorContext(ctx context.Context, msg string, keysAndValues ...interface{}) {
s.ErrorContext(ctx, msg, keysAndValues...)
}

func (s contextLogger) InfoContext(ctx context.Context, msg string, keysAndValues ...interface{}) {
s.InfoContext(ctx, msg, keysAndValues...)
}

func (s contextLogger) DebugContext(ctx context.Context, msg string, keysAndValues ...interface{}) {
s.DebugContext(ctx, msg, keysAndValues...)
}

func (s contextLogger) WarnContext(ctx context.Context, msg string, keysAndValues ...interface{}) {
s.WarnContext(ctx, msg, keysAndValues...)
}

func (s contextLogger) Printf(msg string, keysAndValues ...interface{}) {
s.InfoContext(s.Context, msg, keysAndValues...)
}

// RequestLogHook allows a function to run before each retry. The HTTP
// request which will be made, and the retry number (0 for the initial
// request) are available to users. The internal logger is exposed to
Expand Down Expand Up @@ -405,7 +442,7 @@ type PrepareRetry func(req *http.Request) error
// like automatic retries to tolerate minor outages.
type Client struct {
HTTPClient *http.Client // Internal HTTP client.
Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger
Logger interface{} // Customer logger instance. Supports Logger, LeveledLogger, or ContextLogger.

RetryWaitMin time.Duration // Minimum time to wait
RetryWaitMax time.Duration // Maximum time to wait
Expand Down Expand Up @@ -456,11 +493,11 @@ func (c *Client) logger() interface{} {
}

switch c.Logger.(type) {
case Logger, LeveledLogger:
case Logger, LeveledLogger, ContextLogger:
// ok
default:
// This should happen in dev when they are setting Logger and work on code, not in prod.
panic(fmt.Sprintf("invalid logger type passed, must be Logger or LeveledLogger, was %T", c.Logger))
panic(fmt.Sprintf("invalid logger type passed, must in Logger, LeveledLogger, ContextLogger, was %T", c.Logger))
}
})

Expand Down Expand Up @@ -657,6 +694,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {

if logger != nil {
switch v := logger.(type) {
case ContextLogger:
v.DebugContext(req.Context(), "performing request", "method", req.Method, "url", redactURL(req.URL))
case LeveledLogger:
v.Debug("performing request", "method", req.Method, "url", redactURL(req.URL))
case Logger:
Expand Down Expand Up @@ -689,6 +728,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {

if c.RequestLogHook != nil {
switch v := logger.(type) {
case ContextLogger:
c.RequestLogHook(contextLogger{req.Context(), v}, req.Request, i)
case LeveledLogger:
c.RequestLogHook(hookLogger{v}, req.Request, i)
case Logger:
Expand All @@ -714,6 +755,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
}
if err != nil {
switch v := logger.(type) {
case ContextLogger:
v.ErrorContext(req.Context(), "request failed", "error", err, "method", req.Method, "url", redactURL(req.URL))
case LeveledLogger:
v.Error("request failed", "error", err, "method", req.Method, "url", redactURL(req.URL))
case Logger:
Expand All @@ -725,6 +768,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
if c.ResponseLogHook != nil {
// Call the response logger function if provided.
switch v := logger.(type) {
case ContextLogger:
c.ResponseLogHook(contextLogger{req.Context(), v}, resp)
case LeveledLogger:
c.ResponseLogHook(hookLogger{v}, resp)
case Logger:
Expand All @@ -748,7 +793,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {

// We're going to retry, consume any response to reuse the connection.
if doErr == nil {
c.drainBody(resp.Body)
c.drainBody(req.Context(), resp.Body)
}

wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp)
Expand All @@ -758,6 +803,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
desc = fmt.Sprintf("%s (status: %d)", desc, resp.StatusCode)
}
switch v := logger.(type) {
case ContextLogger:
v.DebugContext(req.Context(), "retrying request", "request", desc, "timeout", wait, "remaining", remain)
case LeveledLogger:
v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain)
case Logger:
Expand Down Expand Up @@ -811,7 +858,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
// By default, we close the response body and return an error without
// returning the response
if resp != nil {
c.drainBody(resp.Body)
c.drainBody(req.Context(), resp.Body)
}

// this means CheckRetry thought the request was a failure, but didn't
Expand All @@ -826,12 +873,14 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
}

// Try to read the response body so we can reuse this connection.
func (c *Client) drainBody(body io.ReadCloser) {
func (c *Client) drainBody(ctx context.Context, body io.ReadCloser) {
defer body.Close()
_, err := io.Copy(io.Discard, io.LimitReader(body, respReadLimit))
if err != nil {
if c.logger() != nil {
switch v := c.logger().(type) {
case ContextLogger:
v.ErrorContext(ctx, "error reading response body", "error", err)
case LeveledLogger:
v.Error("error reading response body", "error", err)
case Logger:
Expand Down
7 changes: 7 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,9 @@ func TestClient_RequestLogHook(t *testing.T) {
t.Run("RequestLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) {
testClientRequestLogHook(t, LeveledLogger(nil))
})
t.Run("RequestLogHook successfully called with nil typed ContextLogger", func(t *testing.T) {
testClientRequestLogHook(t, ContextLogger(nil))
})
}

func testClientRequestLogHook(t *testing.T, logger interface{}) {
Expand Down Expand Up @@ -639,6 +642,10 @@ func TestClient_ResponseLogHook(t *testing.T) {
buf := new(bytes.Buffer)
testClientResponseLogHook(t, LeveledLogger(nil), buf)
})
t.Run("ResponseLogHook successfully called with nil typed ContextLogger", func(t *testing.T) {
buf := new(bytes.Buffer)
testClientResponseLogHook(t, ContextLogger(nil), buf)
})
}

func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) {
Expand Down