diff --git a/rest/cors_test.go b/rest/cors_test.go new file mode 100644 index 0000000..09bbbc4 --- /dev/null +++ b/rest/cors_test.go @@ -0,0 +1,43 @@ +package rest + +import ( + "net/http" + "testing" + + "github.com/ant0ine/go-json-rest/rest/test" +) + +func TestCorsMiddlewareEmptyAccessControlRequestHeaders(t *testing.T) { + api := NewApi() + + // the middleware to test + api.Use(&CorsMiddleware{ + OriginValidator: func(_ string, _ *Request) bool { + return true + }, + AllowedMethods: []string{ + "GET", + "POST", + "PUT", + }, + AllowedHeaders: []string{ + "Origin", + "Referer", + }, + }) + + // wrap all + handler := api.MakeHandler() + + req, _ := http.NewRequest("OPTIONS", "http://localhost", nil) + req.Header.Set("Origin", "http://another.host") + req.Header.Set("Access-Control-Request-Method", "PUT") + req.Header.Set("Access-Control-Request-Headers", "") + + recorded := test.RunRequest(t, handler, req) + t.Logf("recorded: %+v\n", recorded.Recorder) + recorded.CodeIs(200) + recorded.HeaderIs("Access-Control-Allow-Methods", "GET,POST,PUT") + recorded.HeaderIs("Access-Control-Allow-Headers", "Origin,Referer") + recorded.HeaderIs("Access-Control-Allow-Origin", "http://another.host") +} diff --git a/rest/request.go b/rest/request.go index 9d1d792..f3113ef 100644 --- a/rest/request.go +++ b/rest/request.go @@ -120,6 +120,9 @@ func (r *Request) GetCorsInfo() *CorsInfo { reqHeaders := []string{} rawReqHeaders := r.Header[http.CanonicalHeaderKey("Access-Control-Request-Headers")] for _, rawReqHeader := range rawReqHeaders { + if len(rawReqHeader) == 0 { + continue + } // net/http does not handle comma delimited headers for us for _, reqHeader := range strings.Split(rawReqHeader, ",") { reqHeaders = append(reqHeaders, http.CanonicalHeaderKey(strings.TrimSpace(reqHeader))) diff --git a/rest/request_test.go b/rest/request_test.go index 78c0a2c..4186fee 100644 --- a/rest/request_test.go +++ b/rest/request_test.go @@ -148,3 +148,41 @@ func TestCorsInfoPreflightCors(t *testing.T) { t.Error("OriginUrl must be set") } } + +func TestCorsInfoEmptyAccessControlRequestHeaders(t *testing.T) { + req := defaultRequest("OPTIONS", "http://localhost", nil, t) + req.Request.Header.Set("Origin", "http://another.host") + + // make it a preflight request + req.Request.Header.Set("Access-Control-Request-Method", "PUT") + + // WebKit based browsers may send `Access-Control-Request-Headers:` with + // no value, in which case, the header will be present in requests + // Header map, but its value is an empty string. + req.Request.Header.Set("Access-Control-Request-Headers", "") + corsInfo := req.GetCorsInfo() + if corsInfo == nil { + t.Error("Expected non nil CorsInfo") + } + if corsInfo.IsCors == false { + t.Error("This is a CORS request") + } + if len(corsInfo.AccessControlRequestHeaders) > 0 { + t.Error("Access-Control-Request-Headers should have been removed") + } + + req.Request.Header.Set("Access-Control-Request-Headers", "") + corsInfo = req.GetCorsInfo() + if corsInfo == nil { + t.Error("Expected non nil CorsInfo") + } + if corsInfo.IsCors == false { + t.Error("This is a CORS request") + } + if corsInfo.IsPreflight == false { + t.Error("This is a Preflight request") + } + if len(corsInfo.AccessControlRequestHeaders) > 0 { + t.Error("Empty Access-Control-Request-Headers header should have been removed") + } +}