diff --git a/.travis.yml b/.travis.yml index 7ee5ebe..6ade40e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,17 +7,16 @@ branches: # skip tags build, we are building branch and master that is enough for # consistenty check and release. Let's use Travis CI resources optimally # for aah framework. - - /^v[0-9]\.[0-9]/ + - /^v[0-9.]+$/ go: - - 1.8 - 1.9 + - "1.10" - tip go_import_path: aahframework.org/ahttp.v0 install: - - git config --global http.https://aahframework.org.followRedirects true - go get -t -v ./... script: diff --git a/LICENSE b/LICENSE index 491ce72..02ad2ea 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2016-2017 Jeevanandam M., https://myjeeva.com +Copyright (c) 2016-2018 Jeevanandam M., https://myjeeva.com Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 6f07137..77aca78 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,21 @@ -# ahttp - aah framework -[![Build Status](https://travis-ci.org/go-aah/ahttp.svg?branch=master)](https://travis-ci.org/go-aah/ahttp) [![codecov](https://codecov.io/gh/go-aah/ahttp/branch/master/graph/badge.svg)](https://codecov.io/gh/go-aah/ahttp/branch/master) [![Go Report Card](https://goreportcard.com/badge/aahframework.org/ahttp.v0-unstable)](https://goreportcard.com/report/aahframework.org/ahttp.v0-unstable) [![Version](https://img.shields.io/badge/version-0.10-blue.svg)](https://github.com/go-aah/ahttp/releases/latest) [![GoDoc](https://godoc.org/aahframework.org/ahttp.v0-unstable?status.svg)](https://godoc.org/aahframework.org/ahttp.v0-unstable) [![License](https://img.shields.io/github/license/go-aah/ahttp.svg)](LICENSE) [![Twitter](https://img.shields.io/badge/twitter-@aahframework-55acee.svg)](https://twitter.com/aahframework) +

+ +

HTTP extension library by aah framework

+

+

+

Build Status Code Coverage Go Report Card Release Version Godoc Twitter @aahframework

+

-***v0.10 [released](https://github.com/go-aah/ahttp/releases/latest) and tagged on Sep 01, 2017*** +HTTP extension Library is used to handle/process Request and Response (headers, body, gzip, etc). -HTTP Library built to process, manipulate Request and Response (headers, body, gzip, etc). +### News -*`ahttp` developed for aah framework. However, it's an independent library, can be used separately with any `Go` language project. Feel free to use it.* + * `v0.11.0` [released](https://github.com/go-aah/ahttp/releases/latest) and tagged on Jul 06, 2018. + +## Installation -# Installation -#### Stable Version - Production Ready ```bash -# install the library go get -u aahframework.org/ahttp.v0 ``` -Visit official website https://aahframework.org to learn more. +Visit official website https://aahframework.org to learn more about `aah` framework. diff --git a/ahttp.go b/ahttp.go index 3014658..90a3589 100644 --- a/ahttp.go +++ b/ahttp.go @@ -1,5 +1,5 @@ // Copyright (c) Jeevanandam M (https://github.com/jeevatkm) -// go-aah/ahttp source code and usage is governed by a MIT style +// aahframework.org/ahttp source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // Package ahttp is to cater HTTP helper methods for aah framework. @@ -8,7 +8,11 @@ package ahttp import ( "io" + "net" "net/http" + "strings" + + "aahframework.org/essentials.v0" ) // HTTP Method names @@ -24,6 +28,13 @@ const ( MethodTrace = http.MethodTrace ) +// URI Protocol scheme names +const ( + SchemeHTTP = "http" + SchemeHTTPS = "https" + SchemeFTP = "ftp" +) + // TimeFormat is the time format to use when generating times in HTTP // headers. It is like time.RFC1123 but hard-codes GMT as the time // zone. The time being formatted must be in UTC for Format to @@ -53,6 +64,7 @@ func AcquireRequest(r *http.Request) *Request { // ReleaseRequest method resets the instance value and puts back to pool. func ReleaseRequest(r *Request) { if r != nil { + r.cleanupMutlipart() r.Reset() requestPool.Put(r) } @@ -83,3 +95,71 @@ func WrapGzipWriter(w io.Writer) ResponseWriter { gr.r = w.(*Response) return gr } + +// Scheme method is to identify value of protocol value. It's is derived +// one, Go language doesn't provide directly. +// +// - `X-Forwarded-Proto` is not empty, returns as-is +// +// - `X-Forwarded-Protocol` is not empty, returns as-is +// +// - `http.Request.TLS` is not nil or `X-Forwarded-Ssl == on` returns `https` +// +// - `X-Url-Scheme` is not empty, returns as-is +// +// - returns `http` +func Scheme(r *http.Request) string { + if scheme := r.Header.Get(HeaderXForwardedProto); scheme != "" { + return scheme + } + + if scheme := r.Header.Get(HeaderXForwardedProtocol); scheme != "" { + return scheme + } + + if r.TLS != nil || r.Header.Get(HeaderXForwardedSsl) == "on" { + return "https" + } + + if scheme := r.Header.Get(HeaderXUrlScheme); scheme != "" { + return scheme + } + + return "http" +} + +// Host method is to correct Host value from HTTP request. +func Host(r *http.Request) string { + if r.URL.Host == "" { + return r.Host + } + return r.URL.Host +} + +// ClientIP method returns remote Client IP address aka Remote IP. +// +// It parses in the order of given headers otherwise it uses default +// default header set `X-Forwarded-For`, `X-Real-IP`, "X-Appengine-Remote-Addr" +// and finally `http.Request.RemoteAddr`. +func ClientIP(r *http.Request, hdrs ...string) string { + if len(hdrs) == 0 { + hdrs = []string{"X-Forwarded-For", "X-Real-IP", "X-Appengine-Remote-Addr"} + } + + for _, hdrKey := range hdrs { + if hv := r.Header.Get(hdrKey); !ess.IsStrEmpty(hv) { + index := strings.Index(hv, ",") + if index == -1 { + return strings.TrimSpace(hv) + } + return strings.TrimSpace(hv[:index]) + } + } + + // Remote Address + if remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return strings.TrimSpace(remoteAddr) + } + + return "" +} diff --git a/content_type.go b/content_type.go index 6d629af..eeb4257 100644 --- a/content_type.go +++ b/content_type.go @@ -36,26 +36,32 @@ var ( // ContentTypeOctetStream content type for bytes. ContentTypeOctetStream = parseMediaType("application/octet-stream") -) -type ( - // ContentType is represents request and response content type values - ContentType struct { - Mime string - Exts []string - Params map[string]string - } + // ContentTypeJavascript content type. + ContentTypeJavascript = parseMediaType("application/javascript; charset=utf-8") + + // ContentTypeEventStream Server-Sent Events content type. + ContentTypeEventStream = parseMediaType("text/event-stream") ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Content-Type methods +// Content-Type //___________________________________ -// IsEqual method compares give Content-Type string with current instance. +// ContentType is represents request and response content type values +type ContentType struct { + Mime string + Exts []string + Params map[string]string +} + +// IsEqual method returns true if its equals to current content-type instance +// otherwise false. // E.g.: // contentType.IsEqual("application/json") +// contentType.IsEqual("application/json; charset=utf-8") func (c *ContentType) IsEqual(contentType string) bool { - return strings.HasPrefix(c.String(), strings.ToLower(contentType)) + return strings.HasPrefix(contentType, c.Mime) } // Charset method returns charset of content-type @@ -65,7 +71,7 @@ func (c *ContentType) IsEqual(contentType string) bool { // // Method returns `utf-8` func (c *ContentType) Charset(defaultCharset string) string { - if v, ok := c.Params["charset"]; ok { + if v, found := c.Params["charset"]; found { return v } return defaultCharset @@ -117,6 +123,6 @@ func (c *ContentType) Raw() string { } // String is stringer interface -func (c *ContentType) String() string { +func (c ContentType) String() string { return c.Raw() } diff --git a/content_type_test.go b/content_type_test.go index 768abfc..5542bd7 100644 --- a/content_type_test.go +++ b/content_type_test.go @@ -6,6 +6,7 @@ package ahttp import ( "net/url" + "runtime" "testing" "aahframework.org/essentials.v0" @@ -41,8 +42,11 @@ func TestHTTPNegotiateContentType(t *testing.T) { req := createRawHTTPRequest(HeaderAccept, "application/json") req.URL, _ = url.Parse("http://localhost:8080/testpath.json") contentType = NegotiateContentType(req) - assert.True(t, contentType.IsEqual("application/json")) - assert.Equal(t, ".json", contentType.Exts[0]) + if runtime.GOOS != "windows" { // due to mime types not exists + assert.NotNil(t, contentType) + assert.True(t, contentType.IsEqual("application/json")) + assert.Equal(t, ".json", contentType.Exts[0]) + } req = createRawHTTPRequest(HeaderAccept, "application/json") req.URL, _ = url.Parse("http://localhost:8080/testpath.html") diff --git a/gzip_response.go b/gzip_response.go index 2c2928a..b73d1a3 100644 --- a/gzip_response.go +++ b/gzip_response.go @@ -11,17 +11,8 @@ import ( "net" "net/http" "sync" - - "aahframework.org/essentials.v0" ) -// GzipResponse extends `ahttp.Response` and provides gzip for response -// bytes before writing them to the underlying response. -type GzipResponse struct { - r *Response - gw *gzip.Writer -} - var ( // GzipLevel holds value from app config. GzipLevel int @@ -39,31 +30,16 @@ var ( ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Package methods +// GzipResponse //___________________________________ -// TODO for old method cleanup - -// GetGzipResponseWriter wraps `http.ResponseWriter`, returns aah framework response -// writer that allows to advantage of response process. -// Deprecated use `WrapGzipWriter` instead. -func GetGzipResponseWriter(w ResponseWriter) ResponseWriter { - gr := grPool.Get().(*GzipResponse) - gr.gw = acquireGzipWriter(w) - gr.r = w.(*Response) - return gr -} - -// PutGzipResponseWiriter method resets and puts the gzip writer into pool. -// Deprecated use `ReleaseResponseWriter` instead. -func PutGzipResponseWiriter(rw ResponseWriter) { - releaseGzipResponse(rw.(*GzipResponse)) +// GzipResponse extends `ahttp.Response` to provides gzip compression for response +// bytes to the underlying response. +type GzipResponse struct { + r *Response + gw *gzip.Writer } -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Response methods -//___________________________________ - // Status method returns HTTP response status code. If status is not yet written // it reurns 0. func (g *GzipResponse) Status() int { @@ -82,9 +58,7 @@ func (g *GzipResponse) Header() http.Header { // Write method writes bytes into Response. func (g *GzipResponse) Write(b []byte) (int, error) { - g.r.setContentTypeIfNotSet(b) g.r.WriteHeader(http.StatusOK) - size, err := g.gw.Write(b) g.r.bytesWritten += size return size, err @@ -97,8 +71,9 @@ func (g *GzipResponse) BytesWritten() int { // Close method closes the writer if possible. func (g *GzipResponse) Close() error { - ess.CloseQuietly(g.gw) - g.gw = nil + if err := g.gw.Close(); err != nil { + return err + } return g.r.Close() } @@ -140,9 +115,9 @@ func (g *GzipResponse) Push(target string, opts *http.PushOptions) error { // releaseGzipResponse method resets and puts the gzip response into pool. func releaseGzipResponse(gw *GzipResponse) { - releaseGzipWriter(gw.gw) - releaseResponse(gw.r) _ = gw.Close() + gwPool.Put(gw.gw) + releaseResponse(gw.r) grPool.Put(gw) } @@ -158,8 +133,3 @@ func acquireGzipWriter(w io.Writer) *gzip.Writer { ngw.Reset(w) return ngw } - -func releaseGzipWriter(gw *gzip.Writer) { - _ = gw.Close() - gwPool.Put(gw) -} diff --git a/gzip_response_test.go b/gzip_response_test.go index fc1df37..c6b4ff1 100644 --- a/gzip_response_test.go +++ b/gzip_response_test.go @@ -21,8 +21,8 @@ import ( func TestHTTPGzipWriter(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { GzipLevel = gzip.BestSpeed - gw := GetGzipResponseWriter(GetResponseWriter(w)) - defer PutGzipResponseWiriter(gw) + gw := WrapGzipWriter(AcquireResponseWriter(w)) + defer ReleaseResponseWriter(gw) gw.Header().Set(HeaderVary, HeaderAcceptEncoding) gw.Header().Set(HeaderContentEncoding, "gzip") @@ -92,7 +92,7 @@ func TestHTTPGzipHijack(t *testing.T) { ngw, _ := gzip.NewWriterLevel(w, GzipLevel) gwPool.Put(ngw) } - gw := WrapGzipWriter(GetResponseWriter(w)) + gw := WrapGzipWriter(AcquireResponseWriter(w)) con, rw, err := gw.(http.Hijacker).Hijack() assert.FailOnError(t, err, "") diff --git a/header.go b/header.go index 3680421..a9d6c03 100644 --- a/header.go +++ b/header.go @@ -79,6 +79,9 @@ const ( HeaderXForwardedHost = "X-Forwarded-Host" HeaderXForwardedPort = "X-Forwarded-Port" HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderXForwardedSsl = "X-Forwarded-Ssl" + HeaderXUrlScheme = "X-Url-Scheme" HeaderXForwardedServer = "X-Forwarded-Server" HeaderXFrameOptions = "X-Frame-Options" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" @@ -167,10 +170,9 @@ func NegotiateEncoding(req *http.Request) *AcceptSpec { // ParseContentType method parses the request header `Content-Type` as per RFC1521. func ParseContentType(req *http.Request) *ContentType { contentType := req.Header.Get(HeaderContentType) - if ess.IsStrEmpty(contentType) { + if contentType == "" { return ContentTypeHTML } - return parseMediaType(contentType) } @@ -281,7 +283,7 @@ func NewLocale(value string) *Locale { //___________________________________ // String is stringer interface. -func (l *Locale) String() string { +func (l Locale) String() string { return l.Raw } diff --git a/header_test.go b/header_test.go index 9c834b9..fea468c 100644 --- a/header_test.go +++ b/header_test.go @@ -94,26 +94,29 @@ func TestHTTPNegotiateLocale(t *testing.T) { func TestHTTPNegotiateEncoding(t *testing.T) { req1 := createRawHTTPRequest(HeaderAcceptEncoding, "compress;q=0.5, gzip;q=1.0") - encoding := NegotiateEncoding(req1) + areq1 := AcquireRequest(req1) + encoding := areq1.AcceptEncoding() + assert.True(t, areq1.IsGzipAccepted) assert.Equal(t, "gzip", encoding.Value) assert.Equal(t, "gzip;q=1.0", encoding.Raw) - assert.True(t, isGzipAccepted(&Request{}, req1)) req2 := createRawHTTPRequest(HeaderAcceptEncoding, "gzip;q=1.0, identity; q=0.5, *;q=0") - encoding = NegotiateEncoding(req2) + areq2 := AcquireRequest(req2) + encoding = areq2.AcceptEncoding() + assert.True(t, areq2.IsGzipAccepted) assert.Equal(t, "gzip", encoding.Value) assert.Equal(t, "gzip;q=1.0", encoding.Raw) - assert.True(t, isGzipAccepted(&Request{}, req1)) req3 := createRawHTTPRequest(HeaderAcceptEncoding, "") encoding = NegotiateEncoding(req3) assert.Equal(t, true, encoding == nil) req4 := createRawHTTPRequest(HeaderAcceptEncoding, "compress;q=0.5") - encoding = NegotiateEncoding(req4) + areq4 := AcquireRequest(req4) + encoding = areq4.AcceptEncoding() + assert.False(t, areq4.IsGzipAccepted) assert.Equal(t, "compress", encoding.Value) assert.Equal(t, "compress;q=0.5", encoding.Raw) - assert.False(t, isGzipAccepted(&Request{}, req4)) } func TestHTTPAcceptHeaderVendorType(t *testing.T) { diff --git a/max_bytes_reader.go.bak b/max_bytes_reader.go.bak new file mode 100644 index 0000000..6a93489 --- /dev/null +++ b/max_bytes_reader.go.bak @@ -0,0 +1,80 @@ +package ahttp + +import ( + "errors" + "fmt" + "io" + "sync" +) + +var maxBytesReaderPool = &sync.Pool{New: func() interface{} { return &maxBytesReader{} }} + +// maxBytesReader is a minimal version of net/http package maxBytesReader for aah. +// so that we do memory pool, much more. +// +// MaxBytesReader's result is a ReadCloser, returns a +// non-EOF error for a Read beyond the limit, and closes the +// underlying reader when its Close method is called. +// +// MaxBytesReader prevents clients from accidentally or maliciously +// sending a large request and wasting server resources. +type maxBytesReader struct { + w ResponseWriter + r io.ReadCloser // underlying reader + n int64 // max bytes remaining + err error // sticky error +} + +func (mr *maxBytesReader) Read(p []byte) (n int, err error) { + if mr.err != nil { + return 0, mr.err + } + + if len(p) == 0 { + return 0, nil + } + + // If they asked for a 32KB byte read but only 5 bytes are + // remaining, no need to read 32KB. 6 bytes will answer the + // question of the whether we hit the limit or go past it. + if int64(len(p)) > mr.n+1 { + p = p[:mr.n+1] + } + + n, err = mr.r.Read(p) + if int64(n) <= mr.n { + mr.n -= int64(n) + mr.err = err + return n, err + } + + n = int(mr.n) + mr.n = 0 + + // Set the header to close the connection + mr.w.Header().Set(HeaderConnection, "close") + mr.err = errors.New("ahttp: request body too large") + _ = mr.Close() + return n, mr.err +} + +func (mr *maxBytesReader) Close() error { + fmt.Println("maxBytesReader close called") + return mr.r.Close() +} + +func (mr *maxBytesReader) Reset() { + mr.w = nil + mr.r = nil + mr.n = 0 + mr.err = nil +} + +func releaseMaxBytesReader(r *Request) { + if r.Raw.Body != nil { + if mr, ok := r.Raw.Body.(*maxBytesReader); ok { + mr.Reset() + maxBytesReaderPool.Put(mr) + } + } +} diff --git a/request.go b/request.go index bef424b..0039f41 100644 --- a/request.go +++ b/request.go @@ -1,5 +1,5 @@ // Copyright (c) Jeevanandam M (https://github.com/jeevatkm) -// go-aah/ahttp source code and usage is governed by a MIT style +// aahframework.org/ahttp source code and usage is governed by a MIT style // license that can be found in the LICENSE file. package ahttp @@ -9,11 +9,9 @@ import ( "fmt" "io" "mime/multipart" - "net" "net/http" "net/url" "os" - "path/filepath" "strings" "sync" @@ -21,119 +19,124 @@ import ( ) const ( - jsonpReqParamKey = "callback" - ajaxHeaderValue = "XMLHttpRequest" - websocketHeaderValue = "websocket" + jsonpReqParamKey = "callback" + ajaxHeaderValue = "XMLHttpRequest" ) var requestPool = &sync.Pool{New: func() interface{} { return &Request{} }} -type ( - // Request is extends `http.Request` for aah framework - Request struct { - // Scheme value is protocol; it's a derived value in the order as below. - // - `X-Forwarded-Proto` is not empty return value as is - // - `http.Request.TLS` is not nil value is `https` - // - `http.Request.TLS` is nil value is `http` - Scheme string - - // Host value of the HTTP 'Host' header (e.g. 'example.com:8080'). - Host string - - // Proto value of the current HTTP request protocol. (e.g. HTTP/1.1, HTTP/2.0) - Proto string +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// Package methods +//___________________________________ - // Method request method e.g. `GET`, `POST`, etc. - Method string +// ParseRequest method populates the given aah framework `ahttp.Request` +// instance from Go HTTP request. +func ParseRequest(r *http.Request, req *Request) *Request { + req.Scheme = Scheme(r) + req.Host = Host(r) + req.Proto = r.Proto + req.Method = r.Method + req.Path = r.URL.Path + req.Header = r.Header + req.Referer = getReferer(r.Header) + req.UserAgent = r.Header.Get(HeaderUserAgent) + req.IsGzipAccepted = strings.Contains(r.Header.Get(HeaderAcceptEncoding), "gzip") + req.raw = r + return req +} - // Path the request URL Path e.g. `/app/login.html`. - Path string +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// Request +//___________________________________ - // Header request HTTP headers - Header http.Header +// Request type extends `http.Request` and provides multiple helper methods +// per industry RFC guideline for aah framework. +type Request struct { + // Scheme value is protocol, refer to method `ahttp.Scheme`. + Scheme string - // ContentType the parsed value of HTTP header `Content-Type`. - // Partial implementation as per RFC1521. - ContentType *ContentType + // Host value is HTTP 'Host' header (e.g. 'example.com:8080'). + Host string - // AcceptContentType negotiated value from HTTP Header `Accept`. - // The resolve order is- - // 1) URL extension - // 2) Accept header (As per RFC7231 and vendor type as per RFC4288) - // Most quailfied one based on quality factor otherwise default is HTML. - AcceptContentType *ContentType + // Proto value is current HTTP request protocol. (e.g. HTTP/1.1, HTTP/2.0) + Proto string - // AcceptEncoding negotiated value from HTTP Header the `Accept-Encoding` - // As per RFC7231. - // Most quailfied one based on quality factor. - AcceptEncoding *AcceptSpec + // Method value is HTTP verb from request e.g. `GET`, `POST`, etc. + Method string - // Params contains values from Path, Query, Form and File. - Params *Params + // Path value is request relative URL Path e.g. `/app/login.html`. + Path string - // Referer value of the HTTP 'Referrer' (or 'Referer') header. - Referer string + // Header is request HTTP headers + Header http.Header - // UserAgent value of the HTTP 'User-Agent' header. - UserAgent string + // PathParams value is URL path parameters. + PathParams PathParams - // ClientIP remote client IP address aka Remote IP. Parsed in the order of - // `X-Forwarded-For`, `X-Real-IP` and finally `http.Request.RemoteAddr`. - ClientIP string + // Referer value is HTTP 'Referrer' (or 'Referer') header. + Referer string - // Locale negotiated value from HTTP Header `Accept-Language`. - // As per RFC7231. - Locale *Locale + // UserAgent value is HTTP 'User-Agent' header. + UserAgent string - // IsGzipAccepted is true if the HTTP client accepts Gzip response, - // otherwise false. - IsGzipAccepted bool + // IsGzipAccepted is true if the HTTP client accepts Gzip response, + // otherwise false. + IsGzipAccepted bool - // Raw an object of Go HTTP server, direct interaction with - // raw object is not encouraged. - // - // DEPRECATED: Raw field to be unexported on v1 release, use `Req.Unwarp()` instead. - Raw *http.Request - } + raw *http.Request + locale *Locale + contentType *ContentType + acceptContentType *ContentType + acceptEncoding *AcceptSpec +} - // Params structure holds value of Path, Query, Form and File. - Params struct { - Path map[string]string - Query url.Values - Form url.Values - File map[string][]*multipart.FileHeader +// AcceptContentType method returns negotiated value. +// +// The resolve order is- +// +// 1) URL extension +// +// 2) Accept header (As per RFC7231 and vendor type as per RFC4288) +// +// Most quailfied one based on quality factor otherwise default is Plain text. +func (r *Request) AcceptContentType() *ContentType { + if r.acceptContentType == nil { + r.acceptContentType = NegotiateContentType(r.Unwrap()) } -) + return r.acceptContentType +} -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Package methods -//___________________________________ +// SetAcceptContentType method is used to set Accept ContentType instance. +func (r *Request) SetAcceptContentType(contentType *ContentType) *Request { + r.acceptContentType = contentType + return r +} -// ParseRequest method populates the given aah framework `ahttp.Request` -// instance from Go HTTP request. -func ParseRequest(r *http.Request, req *Request) *Request { - req.Scheme = identifyScheme(r) - req.Host = host(r) - req.Proto = r.Proto - req.Method = r.Method - req.Path = r.URL.Path - req.Header = r.Header - req.ContentType = ParseContentType(r) - req.AcceptContentType = NegotiateContentType(r) - req.Params = &Params{Query: r.URL.Query()} - req.Referer = getReferer(r.Header) - req.UserAgent = r.Header.Get(HeaderUserAgent) - req.ClientIP = clientIP(r) - req.Locale = NegotiateLocale(r) - req.IsGzipAccepted = isGzipAccepted(req, r) - req.Raw = r +// AcceptEncoding method returns negotiated value from HTTP Header the `Accept-Encoding` +// As per RFC7231. +// +// Most quailfied one based on quality factor. +func (r *Request) AcceptEncoding() *AcceptSpec { + if r.acceptEncoding == nil { + if specs := ParseAcceptEncoding(r.Unwrap()); specs != nil { + r.acceptEncoding = specs.MostQualified() + } + } + return r.acceptEncoding +} - return req +// SetAcceptEncoding method is used to accept encoding spec instance. +func (r *Request) SetAcceptEncoding(encoding *AcceptSpec) *Request { + r.acceptEncoding = encoding + return r } -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Request methods -//___________________________________ +// ClientIP method returns remote client IP address aka Remote IP. +// +// Refer to method `ahttp.ClientIP`. +func (r *Request) ClientIP() string { + return ClientIP(r.Unwrap()) +} // Cookie method returns a named cookie from HTTP request otherwise error. func (r *Request) Cookie(name string) (*http.Cookie, error) { @@ -145,6 +148,35 @@ func (r *Request) Cookies() []*http.Cookie { return r.Unwrap().Cookies() } +// ContentType method returns the parsed value of HTTP header `Content-Type` per RFC1521. +func (r *Request) ContentType() *ContentType { + if r.contentType == nil { + r.contentType = ParseContentType(r.Unwrap()) + } + return r.contentType +} + +// SetContentType method is used to set ContentType instance. +func (r *Request) SetContentType(contType *ContentType) *Request { + r.contentType = contType + return r +} + +// Locale method returns negotiated value from HTTP Header `Accept-Language` +// per RFC7231. +func (r *Request) Locale() *Locale { + if r.locale == nil { + r.locale = NegotiateLocale(r.Unwrap()) + } + return r.locale +} + +// SetLocale method is used to set locale instance in to aah request. +func (r *Request) SetLocale(locale *Locale) *Request { + r.locale = locale + return r +} + // IsJSONP method returns true if request URL query string has "callback=function_name". // otherwise false. func (r *Request) IsJSONP() bool { @@ -157,44 +189,52 @@ func (r *Request) IsAJAX() bool { return r.Header.Get(HeaderXRequestedWith) == ajaxHeaderValue } -// IsWebSocket method returns true if request is WebSocket otherwise false. -func (r *Request) IsWebSocket() bool { - return r.Header.Get(HeaderUpgrade) == websocketHeaderValue +// URL method return underlying request URL instance. +func (r *Request) URL() *url.URL { + return r.Unwrap().URL } // PathValue method returns value for given Path param key otherwise empty string. // For eg.: /users/:userId => PathValue("userId") func (r *Request) PathValue(key string) string { - return r.Params.PathValue(key) + return r.PathParams.Get(key) } // QueryValue method returns value for given URL query param key // otherwise empty string. func (r *Request) QueryValue(key string) string { - return r.Params.QueryValue(key) + return r.URL().Query().Get(key) } // QueryArrayValue method returns array value for given URL query param key // otherwise empty string slice. func (r *Request) QueryArrayValue(key string) []string { - return r.Params.QueryArrayValue(key) + if values, found := r.URL().Query()[key]; found { + return values + } + return []string{} } // FormValue method returns value for given form key otherwise empty string. func (r *Request) FormValue(key string) string { - return r.Params.FormValue(key) + return r.Unwrap().FormValue(key) } // FormArrayValue method returns array value for given form key // otherwise empty string slice. func (r *Request) FormArrayValue(key string) []string { - return r.Params.FormArrayValue(key) + if r.Unwrap().Form != nil { + if values, found := r.Unwrap().Form[key]; found { + return values + } + } + return []string{} } // FormFile method returns the first file for the provided form key otherwise // returns error. It is caller responsibility to close the file. func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { - return r.Params.FormFile(key) + return r.Unwrap().FormFile(key) } // Body method returns the HTTP request body. @@ -202,9 +242,10 @@ func (r *Request) Body() io.ReadCloser { return r.Unwrap().Body } -// Unwrap method returns the underlying *http.Request. +// Unwrap method returns the underlying *http.Request instance of Go HTTP server, +// direct interaction with raw object is not encouraged. Use it appropriately. func (r *Request) Unwrap() *http.Request { - return r.Raw + return r.raw } // SaveFile method saves an uploaded multipart file for given key from the HTTP @@ -227,37 +268,6 @@ func (r *Request) SaveFile(key, dstFile string) (int64, error) { return saveFile(uploadedFile, dstFile) } -// SaveFiles method saves an uploaded multipart file(s) for the given key -// from the HTTP request into given destination directory. It uses the filename -// as uploaded filename from the request -func (r *Request) SaveFiles(key, dstPath string) ([]int64, []error) { - if !ess.IsDir(dstPath) { - return []int64{0}, []error{fmt.Errorf("ahttp: destination path, '%s' is not a directory", dstPath)} - } - - if ess.IsStrEmpty(key) { - return []int64{0}, []error{fmt.Errorf("ahttp: form file key, '%s' is empty", key)} - } - - var errs []error - var sizes []int64 - for _, file := range r.Params.File[key] { - uploadedFile, err := file.Open() - if err != nil { - sizes = append(sizes, 0) - errs = append(errs, err) - continue - } - - if size, err := saveFile(uploadedFile, filepath.Join(dstPath, file.Filename)); err != nil { - sizes = append(sizes, size) - errs = append(errs, err) - } - ess.CloseQuietly(uploadedFile) - } - return sizes, errs -} - // Reset method resets request instance for reuse. func (r *Request) Reset() { r.Scheme = "" @@ -266,158 +276,56 @@ func (r *Request) Reset() { r.Method = "" r.Path = "" r.Header = nil - r.ContentType = nil - r.AcceptContentType = nil - r.AcceptEncoding = nil - r.Params = nil + r.PathParams = nil r.Referer = "" r.UserAgent = "" - r.ClientIP = "" - r.Locale = nil r.IsGzipAccepted = false - r.Raw = nil -} -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Params methods -//___________________________________ + r.raw = nil + r.locale = nil + r.contentType = nil + r.acceptContentType = nil + r.acceptEncoding = nil +} -// PathValue method returns value for given Path param key otherwise empty string. -// For eg.: `/users/:userId` => `PathValue("userId")`. -func (p *Params) PathValue(key string) string { - if p.Path != nil { - if value, found := p.Path[key]; found { - return value - } +func (r *Request) cleanupMutlipart() { + if r.Unwrap().MultipartForm != nil { + r.Unwrap().MultipartForm.RemoveAll() } - return "" } -// QueryValue method returns value for given URL query param key -// otherwise empty string. -func (p *Params) QueryValue(key string) string { - return p.Query.Get(key) -} +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// PathParams +//___________________________________ -// QueryArrayValue method returns array value for given URL query param key -// otherwise empty string slice. -func (p *Params) QueryArrayValue(key string) []string { - if values, found := p.Query[key]; found { - return values - } - return []string{} -} +// PathParams struct holds the path parameter key and values. +type PathParams map[string]string -// FormValue method returns value for given form key otherwise empty string. -func (p *Params) FormValue(key string) string { - if p.Form != nil { - return p.Form.Get(key) +// Get method returns the value for the given key otherwise empty string. +func (p PathParams) Get(key string) string { + if value, found := p[key]; found { + return value } return "" } -// FormArrayValue method returns array value for given form key -// otherwise empty string slice. -func (p *Params) FormArrayValue(key string) []string { - if p.Form != nil { - if values, found := p.Form[key]; found { - return values - } - } - return []string{} -} - -// FormFile method returns the first file for the provided form key -// otherwise returns error. It is caller responsibility to close the file. -func (p *Params) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { - if p.File != nil { - if fh := p.File[key]; len(fh) > 0 { - f, err := fh[0].Open() - return f, fh[0], err - } - return nil, nil, fmt.Errorf("ahttp: no such key/file: %s", key) - } - return nil, nil, nil +// Len method returns count of total no. of values. +func (p PathParams) Len() int { + return len(p) } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Unexported methods //___________________________________ -// identifyScheme method is to identify value of protocol value. It's is derived -// one, Go language doesn't provide directly. -// - `X-Forwarded-Proto` is not empty return value as is -// - `http.Request.TLS` is not nil value is `https` -// - `http.Request.TLS` is nil value is `http` -func identifyScheme(r *http.Request) string { - scheme := r.Header.Get(HeaderXForwardedProto) - if !ess.IsStrEmpty(scheme) { - return scheme - } - - if r.TLS != nil { - return "https" - } - - return "http" -} - -// clientIP returns IP address from HTTP request, typically known as Client IP or -// Remote IP. It parses the IP in the order of X-Forwarded-For, X-Real-IP -// and finally `http.Request.RemoteAddr`. -func clientIP(req *http.Request) string { - // Header X-Forwarded-For - if fwdFor := req.Header.Get(HeaderXForwardedFor); !ess.IsStrEmpty(fwdFor) { - index := strings.Index(fwdFor, ",") - if index == -1 { - return strings.TrimSpace(fwdFor) - } - return strings.TrimSpace(fwdFor[:index]) - } - - // Header X-Real-Ip - if realIP := req.Header.Get(HeaderXRealIP); !ess.IsStrEmpty(realIP) { - return strings.TrimSpace(realIP) - } - - // Remote Address - if remoteAddr, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - return strings.TrimSpace(remoteAddr) - } - - return "" -} - -func host(r *http.Request) string { - if ess.IsStrEmpty(r.URL.Host) { - return r.Host - } - return r.URL.Host -} - func getReferer(hdr http.Header) string { referer := hdr.Get(HeaderReferer) - - if ess.IsStrEmpty(referer) { - referer = hdr.Get("Referrer") + if referer == "" { + return hdr.Get("Referrer") } - return referer } -func isGzipAccepted(req *Request, r *http.Request) bool { - specs := ParseAcceptEncoding(r) - if specs != nil { - req.AcceptEncoding = specs.MostQualified() - for _, v := range specs { - if v.Value == "gzip" { - return true - } - } - } - return false -} - func saveFile(r io.Reader, destFile string) (int64, error) { f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) if err != nil { diff --git a/request_test.go b/request_test.go index 107c930..1e7f408 100644 --- a/request_test.go +++ b/request_test.go @@ -7,6 +7,7 @@ package ahttp import ( "bytes" "crypto/tls" + "errors" "mime/multipart" "net/http" "net/http/httptest" @@ -21,23 +22,23 @@ import ( func TestHTTPClientIP(t *testing.T) { req1 := createRawHTTPRequest(HeaderXForwardedFor, "10.0.0.1, 10.0.0.2") - ipAddress := clientIP(req1) + ipAddress := AcquireRequest(req1).ClientIP() assert.Equal(t, "10.0.0.1", ipAddress) req2 := createRawHTTPRequest(HeaderXForwardedFor, "10.0.0.2") - ipAddress = clientIP(req2) + ipAddress = AcquireRequest(req2).ClientIP() assert.Equal(t, "10.0.0.2", ipAddress) req3 := createRawHTTPRequest(HeaderXRealIP, "10.0.0.3") - ipAddress = clientIP(req3) + ipAddress = AcquireRequest(req3).ClientIP() assert.Equal(t, "10.0.0.3", ipAddress) req4 := createRequestWithHost("127.0.0.1:8080", "192.168.0.1:1234") - ipAddress = clientIP(req4) + ipAddress = AcquireRequest(req4).ClientIP() assert.Equal(t, "192.168.0.1", ipAddress) req5 := createRequestWithHost("127.0.0.1:8080", "") - ipAddress = clientIP(req5) + ipAddress = AcquireRequest(req5).ClientIP() assert.Equal(t, "", ipAddress) } @@ -62,14 +63,15 @@ func TestHTTPParseRequest(t *testing.T) { req.URL, _ = url.Parse("/welcome1.html?_ref=true") aahReq := AcquireRequest(req) + assert.True(t, req.URL == aahReq.URL()) assert.Equal(t, req, aahReq.Unwrap()) assert.Equal(t, "127.0.0.1:8080", aahReq.Host) assert.Equal(t, MethodGet, aahReq.Method) assert.Equal(t, "/welcome1.html", aahReq.Path) assert.Equal(t, "en-gb;leve=1;q=0.8, da, en;level=2;q=0.7, en-us;q=gg", aahReq.Header.Get(HeaderAcceptLanguage)) - assert.Equal(t, "application/json; charset=utf-8", aahReq.ContentType.String()) - assert.Equal(t, "192.168.0.1", aahReq.ClientIP) + assert.Equal(t, "application/json; charset=utf-8", aahReq.ContentType().String()) + assert.Equal(t, "192.168.0.1", aahReq.ClientIP()) assert.Equal(t, "http://localhost:8080/home.html", aahReq.Referer) // Query Value @@ -87,22 +89,24 @@ func TestHTTPParseRequest(t *testing.T) { f, hdr, err := aahReq.FormFile("no_file") assert.Nil(t, f) assert.Nil(t, hdr) - assert.Nil(t, err) + assert.NotNil(t, err) // request Content-Type isn't multipart/form-data assert.False(t, aahReq.IsJSONP()) assert.False(t, aahReq.IsAJAX()) - assert.False(t, aahReq.IsWebSocket()) - // Reset it - aahReq.Reset() - assert.Nil(t, aahReq.Header) - assert.Nil(t, aahReq.ContentType) - assert.Nil(t, aahReq.AcceptContentType) - assert.Nil(t, aahReq.Params) - assert.Nil(t, aahReq.Locale) - assert.Nil(t, aahReq.Raw) - assert.True(t, len(aahReq.UserAgent) == 0) - assert.True(t, len(aahReq.ClientIP) == 0) + aahReq.SetAcceptContentType(nil) + assert.NotNil(t, aahReq.AcceptContentType()) + aahReq.SetLocale(nil) + assert.NotNil(t, aahReq.Locale()) + aahReq.SetContentType(nil) + assert.NotNil(t, aahReq.ContentType()) + aahReq.SetAcceptEncoding(nil) + assert.Nil(t, aahReq.AcceptEncoding()) + + // Release it ReleaseRequest(aahReq) + assert.Nil(t, aahReq.Header) + assert.Nil(t, aahReq.raw) + assert.True(t, aahReq.UserAgent == "") } func TestHTTPRequestParams(t *testing.T) { @@ -111,15 +115,17 @@ func TestHTTPRequestParams(t *testing.T) { req1.Method = MethodPost req1.URL, _ = url.Parse("http://localhost:8080/welcome1.html?_ref=true&names=Test1&names=Test%202") - params1 := AcquireRequest(req1).Params - params1.Path = make(map[string]string) - params1.Path["userId"] = "100001" - assert.Equal(t, "true", params1.QueryValue("_ref")) - assert.Equal(t, "Test1", params1.QueryArrayValue("names")[0]) - assert.Equal(t, "Test 2", params1.QueryArrayValue("names")[1]) - assert.True(t, len(params1.QueryArrayValue("not-exists")) == 0) - assert.Equal(t, "100001", params1.PathValue("userId")) - assert.Equal(t, "", params1.PathValue("accountId")) + aahReq1 := AcquireRequest(req1) + aahReq1.PathParams = PathParams{} + aahReq1.PathParams["userId"] = "100001" + + assert.Equal(t, "true", aahReq1.QueryValue("_ref")) + assert.Equal(t, "Test1", aahReq1.QueryArrayValue("names")[0]) + assert.Equal(t, "Test 2", aahReq1.QueryArrayValue("names")[1]) + assert.True(t, len(aahReq1.QueryArrayValue("not-exists")) == 0) + assert.Equal(t, "100001", aahReq1.PathValue("userId")) + assert.Equal(t, "", aahReq1.PathValue("accountId")) + assert.Equal(t, 1, aahReq1.PathParams.Len()) // Form value form := url.Values{} @@ -132,27 +138,26 @@ func TestHTTPRequestParams(t *testing.T) { _ = req2.ParseForm() aahReq2 := AcquireRequest(req2) - aahReq2.Params.Form = req2.Form - - params2 := aahReq2.Params assert.NotNil(t, aahReq2.Body()) - assert.Equal(t, "welcome", params2.FormValue("username")) - assert.Equal(t, "welcome@welcome.com", params2.FormValue("email")) - assert.Equal(t, "Test1", params2.FormArrayValue("names")[0]) - assert.Equal(t, "Test 2 value", params2.FormArrayValue("names")[1]) - assert.True(t, len(params2.FormArrayValue("not-exists")) == 0) + assert.Equal(t, "welcome", aahReq2.FormValue("username")) + assert.Equal(t, "welcome@welcome.com", aahReq2.FormValue("email")) + assert.Equal(t, "Test1", aahReq2.FormArrayValue("names")[0]) + assert.Equal(t, "Test 2 value", aahReq2.FormArrayValue("names")[1]) + assert.True(t, len(aahReq2.FormArrayValue("not-exists")) == 0) + assert.Equal(t, 0, aahReq2.PathParams.Len()) ReleaseRequest(aahReq2) // File value - req3, _ := http.NewRequest("POST", "http://localhost:8080/user/registration", nil) + req3, _ := http.NewRequest("POST", "http://localhost:8080/user/registration", strings.NewReader(form.Encode())) req3.Header.Add(HeaderContentType, ContentTypeMultipartForm.String()) aahReq3 := AcquireRequest(req3) - aahReq3.Params.File = make(map[string][]*multipart.FileHeader) - aahReq3.Params.File["testfile.txt"] = []*multipart.FileHeader{{Filename: "testfile.txt"}} + aahReq3.Unwrap().MultipartForm = new(multipart.Form) + aahReq3.Unwrap().MultipartForm.File = make(map[string][]*multipart.FileHeader) + aahReq3.Unwrap().MultipartForm.File["testfile.txt"] = []*multipart.FileHeader{{Filename: "testfile.txt"}} f, fh, err := aahReq3.FormFile("testfile.txt") assert.Nil(t, f) assert.Equal(t, "testfile.txt", fh.Filename) - assert.Equal(t, "open : no such file or directory", err.Error()) + assert.True(t, strings.HasPrefix(err.Error(), "open :")) ReleaseRequest(aahReq3) } @@ -179,20 +184,25 @@ func TestHTTPRequestCookies(t *testing.T) { func TestRequestSchemeDerived(t *testing.T) { req := httptest.NewRequest("GET", "http://127.0.0.1:8080/welcome.html", nil) - scheme1 := identifyScheme(req) - assert.Equal(t, "http", scheme1) + assert.Equal(t, "http", Scheme(req)) + + req.Header.Set(HeaderXUrlScheme, "http") + assert.Equal(t, "http", Scheme(req)) + + req.Header.Set(HeaderXForwardedSsl, "on") + assert.Equal(t, "https", Scheme(req)) req.TLS = &tls.ConnectionState{} - scheme2 := identifyScheme(req) - assert.Equal(t, "https", scheme2) + assert.Equal(t, "https", Scheme(req)) + + req.Header.Set(HeaderXForwardedProtocol, "https") + assert.Equal(t, "https", Scheme(req)) req.Header.Set(HeaderXForwardedProto, "https") - scheme3 := identifyScheme(req) - assert.Equal(t, "https", scheme3) + assert.Equal(t, "https", Scheme(req)) req.Header.Set(HeaderXForwardedProto, "http") - scheme4 := identifyScheme(req) - assert.Equal(t, "http", scheme4) + assert.Equal(t, "http", Scheme(req)) } func TestRequestSaveFile(t *testing.T) { @@ -232,7 +242,7 @@ func TestRequestSaveFileFailsForNotFoundFile(t *testing.T) { _, err := aahReq.SaveFile("unknown-key", path) assert.NotNil(t, err) - assert.Equal(t, "ahttp: no such key/file: unknown-key", err.Error()) + assert.Equal(t, errors.New("http: no such file"), err) } func TestRequestSaveFileCannotCreateFile(t *testing.T) { @@ -244,55 +254,12 @@ func TestRequestSaveFileCannotCreateFile(t *testing.T) { assert.True(t, strings.HasPrefix(err.Error(), "ahttp: open /root/aah.txt")) } -func TestRequestSaveFiles(t *testing.T) { - aahReq, dir, teardown := setUpRequestSaveFiles(t) - defer teardown() - - sizes, errs := aahReq.SaveFiles("framework", dir) - assert.Nil(t, errs) - assert.Nil(t, sizes) - _, err := os.Stat(dir + "/aah") - assert.Nil(t, err) - _, err = os.Stat(dir + "/aah2") - assert.Nil(t, err) -} - -func TestRequestSaveFilesFailsVaildation(t *testing.T) { - aahReq, dir, teardown := setUpRequestSaveFiles(t) - defer teardown() - - // Empty key - sizes, errs := aahReq.SaveFiles("", dir) - assert.NotNil(t, errs) - assert.Equal(t, "ahttp: form file key, '' is empty", errs[0].Error()) - assert.Equal(t, int64(0), sizes[0]) - - // Empty directory - sizes, errs = aahReq.SaveFiles("key", "") - assert.NotNil(t, errs) - assert.Equal(t, "ahttp: destination path, '' is not a directory", errs[0].Error()) - assert.Equal(t, int64(0), sizes[0]) -} - -func TestRequestSaveFilesCannotCreateFile(t *testing.T) { - aahReq, _, teardown := setUpRequestSaveFiles(t) - defer teardown() - - sizes, errs := aahReq.SaveFiles("framework", "/root") - assert.NotNil(t, errs) - assert.Equal(t, int64(0), sizes[0]) - - errMsg := errs[0].Error() - assert.True(t, ("ahttp: open /root/aah: permission denied" == errMsg || - "ahttp: destination path, '/root' is not a directory" == errMsg)) -} - func TestRequestSaveFileForExistingFile(t *testing.T) { var buf bytes.Buffer size, err := saveFile(&buf, "testdata/file1.txt") assert.NotNil(t, err) - assert.Equal(t, "ahttp: open testdata/file1.txt: file exists", err.Error()) + assert.True(t, strings.HasPrefix(err.Error(), "ahttp: open testdata/file1.txt:")) assert.Equal(t, int64(0), size) } @@ -301,7 +268,8 @@ func TestRequestSaveFileForExistingFile(t *testing.T) { //___________________________________ func createRequestWithHost(host, remote string) *http.Request { - return &http.Request{Host: host, RemoteAddr: remote, Header: http.Header{}} + url, _ := url.Parse("http://localhost:8080/testpath") + return &http.Request{URL: url, Host: host, RemoteAddr: remote, Header: http.Header{}} } func setUpRequestSaveFile(t *testing.T) (*Request, string, func()) { @@ -315,12 +283,12 @@ func setUpRequestSaveFile(t *testing.T) (*Request, string, func()) { req, _ := http.NewRequest("POST", "http://localhost:8080", buf) req.Header.Add(HeaderContentType, multipartWriter.FormDataContentType()) aahReq := AcquireRequest(req) - aahReq.Params.File = make(map[string][]*multipart.FileHeader) _, header, err := req.FormFile("framework") assert.Nil(t, err) - aahReq.Params.File["framework"] = []*multipart.FileHeader{header} + aahReq.Unwrap().MultipartForm.File = make(map[string][]*multipart.FileHeader) + aahReq.Unwrap().MultipartForm.File["framework"] = []*multipart.FileHeader{header} path := "testdata/aah.txt" @@ -328,33 +296,3 @@ func setUpRequestSaveFile(t *testing.T) (*Request, string, func()) { _ = os.Remove(path) //Teardown } } - -func setUpRequestSaveFiles(t *testing.T) (*Request, string, func()) { - buf := new(bytes.Buffer) - multipartWriter := multipart.NewWriter(buf) - _, err := multipartWriter.CreateFormFile("framework", "aah") - assert.Nil(t, err) - _, err = multipartWriter.CreateFormFile("framework2", "aah2") - assert.Nil(t, err) - - ess.CloseQuietly(multipartWriter) - - req, _ := http.NewRequest("POST", "http://localhost:8080", buf) - req.Header.Add(HeaderContentType, multipartWriter.FormDataContentType()) - aahReq := AcquireRequest(req) - aahReq.Params.File = make(map[string][]*multipart.FileHeader) - - _, header, err := req.FormFile("framework") - assert.Nil(t, err) - _, header2, err := req.FormFile("framework2") - assert.Nil(t, err) - - aahReq.Params.File["framework"] = []*multipart.FileHeader{header, header2} - - dir := "testdata/upload" - - _ = ess.MkDirAll(dir, 0755) - return aahReq, dir, func() { - _ = os.RemoveAll(dir) - } -} diff --git a/response.go b/response.go index 6dc4807..6bee1ec 100644 --- a/response.go +++ b/response.go @@ -11,34 +11,6 @@ import ( "net" "net/http" "sync" - - "aahframework.org/essentials.v0" -) - -type ( - // ResponseWriter extends the `http.ResponseWriter` interface to implements - // aah framework response. - ResponseWriter interface { - http.ResponseWriter - - // Status returns the HTTP status of the request otherwise 0 - Status() int - - // BytesWritten returns the total number of bytes written - BytesWritten() int - - // Unwrap returns the original `ResponseWriter` - Unwrap() http.ResponseWriter - } - - // Response implements multiple interface (CloseNotifier, Flusher, - // Hijacker) and handy methods for aah framework. - Response struct { - w http.ResponseWriter - status int - wroteStatus bool - bytesWritten int - } ) var ( @@ -53,30 +25,34 @@ var ( _ ResponseWriter = (*Response)(nil) ) -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Package methods -//___________________________________ +// ResponseWriter extends the `http.ResponseWriter` interface to implements +// aah framework response. +type ResponseWriter interface { + http.ResponseWriter -// TODO for old method cleanup + // Status returns the HTTP status of the request otherwise 0 + Status() int -// GetResponseWriter method wraps given writer and returns the aah response writer. -// Deprecated use `AcquireResponseWriter` instead. -func GetResponseWriter(w http.ResponseWriter) ResponseWriter { - rw := responsePool.Get().(*Response) - rw.w = w - return rw -} + // BytesWritten returns the total number of bytes written + BytesWritten() int -// PutResponseWriter method puts response writer back to pool. -// Deprecated use `ReleaseResponseWriter` instead. -func PutResponseWriter(aw ResponseWriter) { - releaseResponse(aw.(*Response)) + // Unwrap returns the original `ResponseWriter` + Unwrap() http.ResponseWriter } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Response methods +// Response //___________________________________ +// Response implements multiple interface (CloseNotifier, Flusher, +// Hijacker) and handy methods for aah framework. +type Response struct { + w http.ResponseWriter + status int + wroteStatus bool + bytesWritten int +} + // Status method returns HTTP response status code. If status is not yet written // it reurns 0. func (r *Response) Status() int { @@ -99,9 +75,7 @@ func (r *Response) Header() http.Header { // Write method writes bytes into Response. func (r *Response) Write(b []byte) (int, error) { - r.setContentTypeIfNotSet(b) r.WriteHeader(http.StatusOK) - size, err := r.w.Write(b) r.bytesWritten += size return size, err @@ -114,7 +88,9 @@ func (r *Response) BytesWritten() int { // Close method closes the writer if possible. func (r *Response) Close() error { - ess.CloseQuietly(r.w) + if w, ok := r.w.(io.Closer); ok { + return w.Close() + } return nil } @@ -170,12 +146,6 @@ func (r *Response) Reset() { // Response Unexported methods //___________________________________ -func (r *Response) setContentTypeIfNotSet(b []byte) { - if ct := r.Header().Get(HeaderContentType); ess.IsStrEmpty(ct) { - r.Header().Set(HeaderContentType, http.DetectContentType(b)) - } -} - // releaseResponse method puts response back to pool. func releaseResponse(r *Response) { _ = r.Close() diff --git a/response_test.go b/response_test.go index 4bdbcb0..d3a32f3 100644 --- a/response_test.go +++ b/response_test.go @@ -18,8 +18,8 @@ import ( func TestHTTPResponseWriter(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - writer := GetResponseWriter(w) - defer PutResponseWriter(writer) + writer := AcquireResponseWriter(w) + defer ReleaseResponseWriter(writer) writer.WriteHeader(http.StatusOK) assert.Equal(t, http.StatusOK, writer.Status()) @@ -47,8 +47,8 @@ func TestHTTPNoStatusWritten(t *testing.T) { func TestHTTPMultipleStatusWritten(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - writer := GetResponseWriter(w) - defer PutResponseWriter(writer) + writer := AcquireResponseWriter(w) + defer ReleaseResponseWriter(writer) writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusAccepted) @@ -85,8 +85,8 @@ func TestHTTPHijackCall(t *testing.T) { func TestHTTPCallCloseNotifyAndFlush(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - writer := GetResponseWriter(w) - defer PutResponseWriter(writer) + writer := AcquireResponseWriter(w) + defer ReleaseResponseWriter(writer) _, _ = writer.Write([]byte("aah framework calling close notify and flush")) assert.Equal(t, 44, writer.BytesWritten()) diff --git a/version.go b/version.go index 8ccc555..aaa8b26 100644 --- a/version.go +++ b/version.go @@ -1,8 +1,8 @@ // Copyright (c) Jeevanandam M (https://github.com/jeevatkm) -// go-aah/ahttp source code and usage is governed by a MIT style +// aahframework.org/ahttp source code and usage is governed by a MIT style // license that can be found in the LICENSE file. package ahttp // Version no. of aah framework ahttp library -const Version = "0.10" +const Version = "0.11.0"