diff --git a/middleware/grpc_middleware_test.go b/middleware/grpc_middleware_test.go index 398648c149..542b6f49b1 100644 --- a/middleware/grpc_middleware_test.go +++ b/middleware/grpc_middleware_test.go @@ -50,7 +50,7 @@ func testClient(t *testing.T, l *bufconn.Listener, dialOpts ...grpc.DialOption) func testTokenCheckServer(t *testing.T) *httptest.Server { s := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("authorization") != "Bearer correct token" { + if r.Header.Get("Authorization") != "bearer correct token" { t.Logf("denied request %+v", r) w.WriteHeader(http.StatusForbidden) return @@ -77,7 +77,7 @@ func writeTestConfig(t *testing.T, pattern string, content string) string { type testToken string func (t testToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { - return map[string]string{"authorization": "Bearer " + string(t)}, nil + return map[string]string{"Authorization": "bearer " + string(t)}, nil } func (t testToken) RequireTransportSecurity() bool { return false } diff --git a/pipeline/authn/authenticator_bearer_token.go b/pipeline/authn/authenticator_bearer_token.go index 944ff68cb5..673dd0d7dc 100644 --- a/pipeline/authn/authenticator_bearer_token.go +++ b/pipeline/authn/authenticator_bearer_token.go @@ -141,6 +141,11 @@ func (a *AuthenticatorBearerToken) Authenticate(r *http.Request, session *Authen return errors.WithStack(ErrAuthenticatorNotResponsible) } + if r.Header == nil { + r.Header = make(http.Header) + } + r.Header.Set("Authorization", "bearer "+token) + body, err := forwardRequestToSessionStore(a.client, r, cf) if err != nil { return err diff --git a/pipeline/authn/authenticator_bearer_token_test.go b/pipeline/authn/authenticator_bearer_token_test.go index 4b02efabfd..dec773d527 100644 --- a/pipeline/authn/authenticator_bearer_token_test.go +++ b/pipeline/authn/authenticator_bearer_token_test.go @@ -38,6 +38,7 @@ func TestAuthenticatorBearerToken(t *testing.T) { t.Run("method=authenticate", func(t *testing.T) { for k, tc := range []struct { d string + token string r *http.Request setup func(*testing.T, *httprouter.Router) router func(http.ResponseWriter, *http.Request) @@ -96,6 +97,54 @@ func TestAuthenticatorBearerToken(t *testing.T) { Extra: map[string]interface{}{"foo": "bar"}, }, }, + { + d: "should pass because session token was provided in the correct custom header", + token: "custom-header-token-value", + r: &http.Request{Header: http.Header{"X-Custom-Header": {"custom-header-token-value"}}, URL: &url.URL{Path: ""}}, + router: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Authorization"), "bearer custom-header-token-value") + w.WriteHeader(200) + w.Write([]byte(`{"sub": "123", "extra": {"foo": "bar"}}`)) + }, + config: []byte(`{"token_from": {"header": "X-Custom-Header"}}`), + expectErr: false, + expectSess: &AuthenticationSession{ + Subject: "123", + Extra: map[string]interface{}{"foo": "bar"}, + }, + }, + { + d: "should pass because session token was provided in the correct custom query parameter", + token: "query-param-token-value", + r: &http.Request{Header: http.Header{}, URL: &url.URL{Path: "", RawQuery: "custom-query-param=query-param-token-value"}}, + router: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Authorization"), "bearer query-param-token-value") + w.WriteHeader(200) + w.Write([]byte(`{"sub": "123", "extra": {"foo": "bar"}}`)) + }, + config: []byte(`{"token_from": {"query_parameter": "custom-query-param"}}`), + expectErr: false, + expectSess: &AuthenticationSession{ + Subject: "123", + Extra: map[string]interface{}{"foo": "bar"}, + }, + }, + { + d: "should pass because session token was provided in the correct cookie", + token: "cooke-token-value", + r: &http.Request{Header: http.Header{"Cookie": {"custom-cookie-name=cooke-token-value"}}, URL: &url.URL{Path: ""}}, + router: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Authorization"), "bearer cooke-token-value") + w.WriteHeader(200) + w.Write([]byte(`{"sub": "123", "extra": {"foo": "bar"}}`)) + }, + config: []byte(`{"token_from": {"cookie": "custom-cookie-name"}}`), + expectErr: false, + expectSess: &AuthenticationSession{ + Subject: "123", + Extra: map[string]interface{}{"foo": "bar"}, + }, + }, { d: "should pass through method, path, and headers to auth server; should NOT pass through query parameters by default for backwards compatibility", r: &http.Request{Header: http.Header{"Authorization": {"bearer zyx"}}, URL: &url.URL{Path: "/users/123", RawQuery: "query=string"}, Method: "PUT"}, @@ -308,9 +357,12 @@ func TestAuthenticatorBearerToken(t *testing.T) { tc.config, _ = sjson.SetBytes(tc.config, "check_session_url", testCheckSessionUrl.String()) sess := new(AuthenticationSession) - originalHeaders := http.Header{} + expectedHeaders := http.Header{} for k, v := range tc.r.Header { - originalHeaders[k] = v + expectedHeaders[k] = v + } + if tc.token != "" { + expectedHeaders.Set("Authorization", "bearer "+tc.token) } err = pipelineAuthenticator.Authenticate(tc.r, sess, tc.config, nil) @@ -323,7 +375,7 @@ func TestAuthenticatorBearerToken(t *testing.T) { require.NoError(t, err) } - require.True(t, reflect.DeepEqual(tc.r.Header, originalHeaders)) + require.True(t, reflect.DeepEqual(tc.r.Header, expectedHeaders)) if tc.expectSess != nil { assert.Equal(t, tc.expectSess, sess)