From 15e321ba37defc87db1e5d7479317204d7547b41 Mon Sep 17 00:00:00 2001 From: Yakov Dlougach Date: Tue, 11 Jun 2024 05:49:34 +0100 Subject: [PATCH] Make BaseURL insensitive to trailing slashes for metadata endpoint redirect. (#5458) * Make BaseURL insensitive to trailing slashes for metadata endpoint redirect. Signed-off-by: Yakov Dlougach * Lint renaming Signed-off-by: Yakov Dlougach --------- Signed-off-by: Yakov Dlougach --- flyteadmin/auth/handlers.go | 2 +- flyteadmin/auth/handlers_test.go | 70 ++++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/flyteadmin/auth/handlers.go b/flyteadmin/auth/handlers.go index bb03e1a654..0416245f8c 100644 --- a/flyteadmin/auth/handlers.go +++ b/flyteadmin/auth/handlers.go @@ -484,7 +484,7 @@ func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Re // See https://tools.ietf.org/html/rfc8414 for more information. func GetOIdCMetadataEndpointRedirectHandler(ctx context.Context, authCtx interfaces.AuthenticationContext) http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) { - metadataURL := authCtx.Options().UserAuth.OpenID.BaseURL.ResolveReference(authCtx.GetOIdCMetadataURL()) + metadataURL := authCtx.Options().UserAuth.OpenID.BaseURL.JoinPath("/").ResolveReference(authCtx.GetOIdCMetadataURL()) http.Redirect(writer, request, metadataURL.String(), http.StatusSeeOther) } } diff --git a/flyteadmin/auth/handlers_test.go b/flyteadmin/auth/handlers_test.go index 5428fb9b80..ee106e92cb 100644 --- a/flyteadmin/auth/handlers_test.go +++ b/flyteadmin/auth/handlers_test.go @@ -449,24 +449,60 @@ func TestGetHTTPRequestCookieToMetadataHandler_CustomHeader(t *testing.T) { func TestGetOIdCMetadataEndpointRedirectHandler(t *testing.T) { ctx := context.Background() - metadataPath := mustParseURL(t, OIdCMetadataEndpoint) - mockAuthCtx := mocks.AuthenticationContext{} - mockAuthCtx.OnOptions().Return(&config.Config{ - UserAuth: config.UserAuthConfig{ - OpenID: config.OpenIDOptions{ - BaseURL: stdConfig.URL{URL: mustParseURL(t, "http://www.google.com")}, - }, + type test struct { + name string + baseURL string + metadataPath string + expectedRedirectLocation string + } + tests := []test{ + { + name: "base_url_without_path", + baseURL: "http://www.google.com", + metadataPath: OIdCMetadataEndpoint, + expectedRedirectLocation: "http://www.google.com/.well-known/openid-configuration", }, - }) - - mockAuthCtx.OnGetOIdCMetadataURL().Return(&metadataPath) - handler := GetOIdCMetadataEndpointRedirectHandler(ctx, &mockAuthCtx) - req, err := http.NewRequest("GET", "/xyz", nil) - assert.NoError(t, err) - w := httptest.NewRecorder() - handler(w, req) - assert.Equal(t, http.StatusSeeOther, w.Code) - assert.Equal(t, "http://www.google.com/.well-known/openid-configuration", w.Header()["Location"][0]) + { + name: "base_url_with_path", + baseURL: "https://login.microsoftonline.com/abc/v2.0", + metadataPath: OIdCMetadataEndpoint, + expectedRedirectLocation: "https://login.microsoftonline.com/abc/v2.0/.well-known/openid-configuration", + }, + { + name: "base_url_with_trailing_slash_path", + baseURL: "https://login.microsoftonline.com/abc/v2.0/", + metadataPath: OIdCMetadataEndpoint, + expectedRedirectLocation: "https://login.microsoftonline.com/abc/v2.0/.well-known/openid-configuration", + }, + { + name: "absolute_metadata_path", + baseURL: "https://login.microsoftonline.com/abc/v2.0/", + metadataPath: "/.well-known/openid-configuration", + expectedRedirectLocation: "https://login.microsoftonline.com/.well-known/openid-configuration", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadataPath := mustParseURL(t, tt.metadataPath) + mockAuthCtx := mocks.AuthenticationContext{} + mockAuthCtx.OnOptions().Return(&config.Config{ + UserAuth: config.UserAuthConfig{ + OpenID: config.OpenIDOptions{ + BaseURL: stdConfig.URL{URL: mustParseURL(t, tt.baseURL)}, + }, + }, + }) + + mockAuthCtx.OnGetOIdCMetadataURL().Return(&metadataPath) + handler := GetOIdCMetadataEndpointRedirectHandler(ctx, &mockAuthCtx) + req, err := http.NewRequest("GET", "/xyz", nil) + assert.NoError(t, err) + w := httptest.NewRecorder() + handler(w, req) + assert.Equal(t, http.StatusSeeOther, w.Code) + assert.Equal(t, tt.expectedRedirectLocation, w.Header()["Location"][0]) + }) + } } func TestUserInfoForwardResponseHander(t *testing.T) {