diff --git a/zendesk/attachment.go b/zendesk/attachment.go index adf2062c..6166a453 100644 --- a/zendesk/attachment.go +++ b/zendesk/attachment.go @@ -71,7 +71,7 @@ func (wr *writer) open() error { return err } - wr.prepareRequest(wr.ctx, req) + req = wr.prepareRequest(wr.ctx, req) req.Header.Set("Content-Type", "application/binary") q := req.URL.Query() diff --git a/zendesk/attachment_test.go b/zendesk/attachment_test.go index 9c69db77..812d8f47 100644 --- a/zendesk/attachment_test.go +++ b/zendesk/attachment_test.go @@ -2,6 +2,7 @@ package zendesk import ( "bytes" + "context" "crypto/sha1" "io" "net/http" @@ -12,7 +13,7 @@ import ( ) func TestWrite(t *testing.T) { - file := readFixture(filepath.Join("POST", "upload.json")) + file := readFixture(filepath.Join(http.MethodPost, "upload.json")) h := sha1.New() h.Write(file) expectedSum := h.Sum(nil) @@ -49,6 +50,30 @@ func TestWrite(t *testing.T) { } } +func TestWriteCancelledContext(t *testing.T) { + mockAPI := newMockAPIWithStatus(http.MethodPost, "ticket.json", 201) + defer mockAPI.Close() + + client := newTestClient(mockAPI) + + canceled, cancelFunc := context.WithCancel(ctx) + cancelFunc() + w := client.UploadAttachment(canceled, "foo", "bar") + + file := []byte("body") + r := bytes.NewBuffer(file) + + _, err := io.Copy(w, r) + if err == nil { + t.Fatalf("did not recieve expected error") + } + + _, err = w.Close() + if err == nil { + t.Fatal("Did not receive error when closing writer") + } +} + func TestDeleteUpload(t *testing.T) { mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) @@ -62,6 +87,22 @@ func TestDeleteUpload(t *testing.T) { } } +func TestDeleteUploadCanceledContext(t *testing.T) { + mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + w.Write(nil) + })) + + c := newTestClient(mockAPI) + canceled, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + err := c.DeleteUpload(canceled, "foobar") + if err == nil { + t.Fatal("did not get expected error") + } +} + func TestGetAttachment(t *testing.T) { mockAPI := newMockAPI(http.MethodGet, "attachment.json") client := newTestClient(mockAPI) diff --git a/zendesk/brand_test.go b/zendesk/brand_test.go index a3cdfffe..b7ac3127 100644 --- a/zendesk/brand_test.go +++ b/zendesk/brand_test.go @@ -1,6 +1,7 @@ package zendesk import ( + "context" "net/http" "net/http/httptest" "testing" @@ -17,6 +18,20 @@ func TestCreateBrand(t *testing.T) { } } +func TestCreateBrandCanceledContext(t *testing.T) { + mockAPI := newMockAPIWithStatus(http.MethodPost, "brands.json", http.StatusCreated) + client := newTestClient(mockAPI) + defer mockAPI.Close() + + canceled, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + _, err := client.CreateBrand(canceled, Brand{}) + if err == nil { + t.Fatalf("did not get expected error") + } +} + func TestGetBrand(t *testing.T) { mockAPI := newMockAPI(http.MethodGet, "brand.json") client := newTestClient(mockAPI) @@ -49,6 +64,19 @@ func TestUpdateBrand(t *testing.T) { } } +func TestUpdateBrandCanceledContext(t *testing.T) { + mockAPI := newMockAPIWithStatus(http.MethodPut, "brands.json", http.StatusOK) + client := newTestClient(mockAPI) + defer mockAPI.Close() + + canceled, cancelFunc := context.WithCancel(ctx) + cancelFunc() + _, err := client.UpdateBrand(canceled, int64(1234), Brand{}) + if err == nil { + t.Fatalf("did not get expected error") + } +} + func TestDeleteBrand(t *testing.T) { mockAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) diff --git a/zendesk/ticket_test.go b/zendesk/ticket_test.go index d2d4ce82..2bff88e4 100644 --- a/zendesk/ticket_test.go +++ b/zendesk/ticket_test.go @@ -1,6 +1,7 @@ package zendesk import ( + "context" "encoding/json" "net/http" "sort" @@ -46,6 +47,18 @@ func TestGetTicket(t *testing.T) { } } +func TestGetTicketCanceledContext(t *testing.T) { + mockAPI := newMockAPI(http.MethodGet, "ticket.json") + client := newTestClient(mockAPI) + defer mockAPI.Close() + canceled, cancelFunc := context.WithCancel(ctx) + cancelFunc() + _, err := client.GetTicket(canceled, 2) + if err == nil { + t.Fatal("Did not get error when calling with cancelled context") + } +} + // Test the CustomField unmarshalling fails on an invalid value. // In this case a float64 as CustomField.Value should cause an error. func TestGetTicketWithInvalidCustomField(t *testing.T) { diff --git a/zendesk/zendesk.go b/zendesk/zendesk.go index 264ac4f2..33ee3915 100644 --- a/zendesk/zendesk.go +++ b/zendesk/zendesk.go @@ -90,7 +90,7 @@ func (z *Client) get(ctx context.Context, path string) ([]byte, error) { return nil, err } - z.prepareRequest(ctx, req) + req = z.prepareRequest(ctx, req) resp, err := z.httpClient.Do(req) if err != nil { @@ -123,7 +123,8 @@ func (z *Client) post(ctx context.Context, path string, data interface{}) ([]byt if err != nil { return nil, err } - z.prepareRequest(ctx, req) + + req = z.prepareRequest(ctx, req) resp, err := z.httpClient.Do(req) if err != nil { @@ -157,7 +158,8 @@ func (z *Client) put(ctx context.Context, path string, data interface{}) ([]byte if err != nil { return nil, err } - z.prepareRequest(ctx, req) + + req = z.prepareRequest(ctx, req) resp, err := z.httpClient.Do(req) if err != nil { @@ -186,7 +188,8 @@ func (z *Client) delete(ctx context.Context, path string) error { if err != nil { return err } - z.prepareRequest(ctx, req) + + req = z.prepareRequest(ctx, req) resp, err := z.httpClient.Do(req) if err != nil { @@ -210,10 +213,12 @@ func (z *Client) delete(ctx context.Context, path string) error { } // prepare request sets common request variables such as authn and user agent -func (z *Client) prepareRequest(ctx context.Context, req *http.Request) { - req = req.WithContext(ctx) - z.includeHeaders(req) - req.SetBasicAuth(z.credential.Email(), z.credential.Secret()) +func (z *Client) prepareRequest(ctx context.Context, req *http.Request) *http.Request { + out := req.WithContext(ctx) + z.includeHeaders(out) + out.SetBasicAuth(z.credential.Email(), z.credential.Secret()) + + return out } // includeHeaders set HTTP headers from client.headers to *http.Request