Skip to content

Commit

Permalink
Merge pull request #156 from tamccall/context-fix
Browse files Browse the repository at this point in the history
Fixing issue where api request doesn't use the provided context.
  • Loading branch information
nukosuke authored Jan 9, 2020
2 parents 8f6e34b + debf623 commit f61553e
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 10 deletions.
2 changes: 1 addition & 1 deletion zendesk/attachment.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (wr *writer) open() error {
return err
}

wr.prepareRequest(wr.ctx, req)
req = wr.prepareRequest(wr.ctx, req)
req.Header.Set("Content-Type", "application/binary")

q := req.URL.Query()
Expand Down
43 changes: 42 additions & 1 deletion zendesk/attachment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zendesk

import (
"bytes"
"context"
"crypto/sha1"
"io"
"net/http"
Expand All @@ -12,7 +13,7 @@ import (
)

func TestWrite(t *testing.T) {
file := readFixture(filepath.Join("POST", "upload.json"))
file := readFixture(filepath.Join(http.MethodPost, "upload.json"))
h := sha1.New()
h.Write(file)
expectedSum := h.Sum(nil)
Expand Down Expand Up @@ -49,6 +50,30 @@ func TestWrite(t *testing.T) {
}
}

func TestWriteCancelledContext(t *testing.T) {
mockAPI := newMockAPIWithStatus(http.MethodPost, "ticket.json", 201)
defer mockAPI.Close()

client := newTestClient(mockAPI)

canceled, cancelFunc := context.WithCancel(ctx)
cancelFunc()
w := client.UploadAttachment(canceled, "foo", "bar")

file := []byte("body")
r := bytes.NewBuffer(file)

_, err := io.Copy(w, r)
if err == nil {
t.Fatalf("did not recieve expected error")
}

_, err = w.Close()
if err == nil {
t.Fatal("Did not receive error when closing writer")
}
}

func TestDeleteUpload(t *testing.T) {
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
Expand All @@ -62,6 +87,22 @@ func TestDeleteUpload(t *testing.T) {
}
}

func TestDeleteUploadCanceledContext(t *testing.T) {
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
w.Write(nil)
}))

c := newTestClient(mockAPI)
canceled, cancelFunc := context.WithCancel(ctx)
cancelFunc()

err := c.DeleteUpload(canceled, "foobar")
if err == nil {
t.Fatal("did not get expected error")
}
}

func TestGetAttachment(t *testing.T) {
mockAPI := newMockAPI(http.MethodGet, "attachment.json")
client := newTestClient(mockAPI)
Expand Down
28 changes: 28 additions & 0 deletions zendesk/brand_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zendesk

import (
"context"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -17,6 +18,20 @@ func TestCreateBrand(t *testing.T) {
}
}

func TestCreateBrandCanceledContext(t *testing.T) {
mockAPI := newMockAPIWithStatus(http.MethodPost, "brands.json", http.StatusCreated)
client := newTestClient(mockAPI)
defer mockAPI.Close()

canceled, cancelFunc := context.WithCancel(ctx)
cancelFunc()

_, err := client.CreateBrand(canceled, Brand{})
if err == nil {
t.Fatalf("did not get expected error")
}
}

func TestGetBrand(t *testing.T) {
mockAPI := newMockAPI(http.MethodGet, "brand.json")
client := newTestClient(mockAPI)
Expand Down Expand Up @@ -49,6 +64,19 @@ func TestUpdateBrand(t *testing.T) {
}
}

func TestUpdateBrandCanceledContext(t *testing.T) {
mockAPI := newMockAPIWithStatus(http.MethodPut, "brands.json", http.StatusOK)
client := newTestClient(mockAPI)
defer mockAPI.Close()

canceled, cancelFunc := context.WithCancel(ctx)
cancelFunc()
_, err := client.UpdateBrand(canceled, int64(1234), Brand{})
if err == nil {
t.Fatalf("did not get expected error")
}
}

func TestDeleteBrand(t *testing.T) {
mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
Expand Down
13 changes: 13 additions & 0 deletions zendesk/ticket_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zendesk

import (
"context"
"encoding/json"
"net/http"
"sort"
Expand Down Expand Up @@ -46,6 +47,18 @@ func TestGetTicket(t *testing.T) {
}
}

func TestGetTicketCanceledContext(t *testing.T) {
mockAPI := newMockAPI(http.MethodGet, "ticket.json")
client := newTestClient(mockAPI)
defer mockAPI.Close()
canceled, cancelFunc := context.WithCancel(ctx)
cancelFunc()
_, err := client.GetTicket(canceled, 2)
if err == nil {
t.Fatal("Did not get error when calling with cancelled context")
}
}

// Test the CustomField unmarshalling fails on an invalid value.
// In this case a float64 as CustomField.Value should cause an error.
func TestGetTicketWithInvalidCustomField(t *testing.T) {
Expand Down
21 changes: 13 additions & 8 deletions zendesk/zendesk.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (z *Client) get(ctx context.Context, path string) ([]byte, error) {
return nil, err
}

z.prepareRequest(ctx, req)
req = z.prepareRequest(ctx, req)

resp, err := z.httpClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -123,7 +123,8 @@ func (z *Client) post(ctx context.Context, path string, data interface{}) ([]byt
if err != nil {
return nil, err
}
z.prepareRequest(ctx, req)

req = z.prepareRequest(ctx, req)

resp, err := z.httpClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -157,7 +158,8 @@ func (z *Client) put(ctx context.Context, path string, data interface{}) ([]byte
if err != nil {
return nil, err
}
z.prepareRequest(ctx, req)

req = z.prepareRequest(ctx, req)

resp, err := z.httpClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -186,7 +188,8 @@ func (z *Client) delete(ctx context.Context, path string) error {
if err != nil {
return err
}
z.prepareRequest(ctx, req)

req = z.prepareRequest(ctx, req)

resp, err := z.httpClient.Do(req)
if err != nil {
Expand All @@ -210,10 +213,12 @@ func (z *Client) delete(ctx context.Context, path string) error {
}

// prepare request sets common request variables such as authn and user agent
func (z *Client) prepareRequest(ctx context.Context, req *http.Request) {
req = req.WithContext(ctx)
z.includeHeaders(req)
req.SetBasicAuth(z.credential.Email(), z.credential.Secret())
func (z *Client) prepareRequest(ctx context.Context, req *http.Request) *http.Request {
out := req.WithContext(ctx)
z.includeHeaders(out)
out.SetBasicAuth(z.credential.Email(), z.credential.Secret())

return out
}

// includeHeaders set HTTP headers from client.headers to *http.Request
Expand Down

0 comments on commit f61553e

Please sign in to comment.