From 7924da13a57ecbf3352d84e6d020012723b81fa1 Mon Sep 17 00:00:00 2001 From: Darshit Chanpura <35282393+DarshitChanpura@users.noreply.github.com> Date: Tue, 3 Oct 2023 10:32:15 -0400 Subject: [PATCH] Refactors reRequestAuthentication to call notifyIpAuthFailureListener before sending the response to the channel (#3411) Prior to this change, the ip auth failure listener was not called upon challengeAuthenticator check invocation, which caused AddressBasedRateLimiter to not be invoked. With this change AddressBasedRateLimiter will be invoked upon multiple wrong requests from an ip. Signed-off-by: Darshit Chanpura --- .../IpBruteForceAttacksPreventionTests.java | 27 +++++----- ...cksPreventionWithDomainChallengeTests.java | 32 ++++++++++++ .../cluster/LocalOpenSearchCluster.java | 2 +- .../jwt/AbstractHTTPJwtAuthenticator.java | 6 +-- .../auth/http/jwt/HTTPJwtAuthenticator.java | 6 +-- .../kerberos/HTTPSpnegoAuthenticator.java | 11 ++-- .../http/saml/AuthTokenProcessorHandler.java | 41 ++++----------- .../auth/http/saml/HTTPSamlAuthenticator.java | 19 ++++--- .../security/auth/BackendRegistry.java | 51 ++++++++++-------- .../security/auth/HTTPAuthenticator.java | 15 +++--- .../security/http/HTTPBasicAuthenticator.java | 6 +-- .../http/HTTPClientCertAuthenticator.java | 6 +-- .../security/http/HTTPProxyAuthenticator.java | 6 +-- .../http/OnBehalfOfAuthenticator.java | 6 +-- .../proxy/HTTPExtendedProxyAuthenticator.java | 6 --- .../http/saml/HTTPSamlAuthenticatorTest.java | 52 +++++++++++++------ .../limiting/AddressBasedRateLimiterTest.java | 21 ++++---- .../UserNameBasedRateLimiterTest.java | 21 ++++---- .../cache/DummyHTTPAuthenticator.java | 6 +-- 19 files changed, 181 insertions(+), 159 deletions(-) create mode 100644 src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java index bb16e0be1b..34e79613f6 100644 --- a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionTests.java @@ -12,7 +12,6 @@ import java.util.concurrent.TimeUnit; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,8 +35,8 @@ @RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class IpBruteForceAttacksPreventionTests { - private static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); - private static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); + static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS); + static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS); public static final int ALLOWED_TRIES = 3; public static final int TIME_WINDOW_SECONDS = 3; @@ -51,7 +50,7 @@ public class IpBruteForceAttacksPreventionTests { public static final String CLIENT_IP_8 = "127.0.0.8"; public static final String CLIENT_IP_9 = "127.0.0.9"; - private static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( + static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit( new RateLimiting("internal_authentication_backend_limiting").type("ip") .allowedTries(ALLOWED_TRIES) .timeWindowSeconds(TIME_WINDOW_SECONDS) @@ -60,13 +59,17 @@ public class IpBruteForceAttacksPreventionTests { .maxTrackedClients(500) ); - @ClassRule - public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) - .anonymousAuth(false) - .authFailureListeners(listener) - .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) - .users(USER_1, USER_2) - .build(); + @Rule + public LocalCluster cluster = createCluster(); + + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE) + .users(USER_1, USER_2) + .build(); + } @Rule public LogsRule logsRule = new LogsRule("org.opensearch.security.auth.BackendRegistry"); @@ -151,7 +154,7 @@ public void shouldReleaseIpAddressLock() throws InterruptedException { } } - private static void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { + void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) { var clientConfiguration = new TestRestClientConfiguration().username(user.getName()) .password("incorrect password") .sourceInetAddress(sourceIpAddress); diff --git a/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java new file mode 100644 index 0000000000..6159599119 --- /dev/null +++ b/src/integrationTest/java/org/opensearch/security/IpBruteForceAttacksPreventionWithDomainChallengeTests.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.security; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.junit.runner.RunWith; +import org.opensearch.test.framework.cluster.ClusterManager; +import org.opensearch.test.framework.cluster.LocalCluster; + +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; + +@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class IpBruteForceAttacksPreventionWithDomainChallengeTests extends IpBruteForceAttacksPreventionTests { + @Override + public LocalCluster createCluster() { + return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE) + .anonymousAuth(false) + .authFailureListeners(listener) + .authc(AUTHC_HTTPBASIC_INTERNAL) + .users(USER_1, USER_2) + .build(); + } +} diff --git a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java index c09127e592..77890a4645 100644 --- a/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java +++ b/src/integrationTest/java/org/opensearch/test/framework/cluster/LocalOpenSearchCluster.java @@ -344,7 +344,7 @@ public String toString() { String clusterManagerNodes = nodeByTypeToString(CLUSTER_MANAGER); String dataNodes = nodeByTypeToString(DATA); String clientNodes = nodeByTypeToString(CLIENT); - return "\nES Cluster " + return "\nOS Cluster " + clusterName + "\ncluster manager nodes: " + clusterManagerNodes diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index 4dab3c7740..6dbb7b7676 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -36,7 +36,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -236,11 +235,10 @@ public String[] extractRoles(JwtClaims claims) { protected abstract KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception; @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials authCredentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } public String getRequiredAudience() { diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index 03e385d5c0..338a490037 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -31,7 +31,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -171,11 +170,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, ""); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index 29f537e899..99bc23b746 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -49,7 +49,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.env.Environment; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -280,8 +279,7 @@ public GSSCredential run() throws GSSException { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse; XContentBuilder response = getNegotiateResponseBody(); @@ -291,16 +289,15 @@ public boolean reRequestAuthentication(final RestChannel channel, AuthCredential wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, EMPTY_STRING); } - if (creds == null || creds.getNativeCredentials() == null) { + if (credentials == null || credentials.getNativeCredentials() == null) { wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate"); } else { wwwAuthenticateResponse.addHeader( "WWW-Authenticate", - "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials()) + "Negotiate " + Base64.getEncoder().encodeToString((byte[]) credentials.getNativeCredentials()) ); } - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 3fee4a9444..7210ed5950 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -23,7 +23,6 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; -import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPathExpressionException; import com.fasterxml.jackson.core.JsonParseException; @@ -32,7 +31,6 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.base.Strings; import com.onelogin.saml2.authn.SamlResponse; -import com.onelogin.saml2.exception.SettingsException; import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.settings.Saml2Settings; import com.onelogin.saml2.util.Util; @@ -49,7 +47,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.joda.time.DateTime; -import org.xml.sax.SAXException; import org.opensearch.OpenSearchSecurityException; import org.opensearch.SpecialPermission; @@ -57,7 +54,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; import org.opensearch.core.rest.RestStatus; @@ -122,7 +118,7 @@ class AuthTokenProcessorHandler { } @SuppressWarnings("removal") - boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception { + BytesRestResponse handle(RestRequest restRequest) throws Exception { try { final SecurityManager sm = System.getSecurityManager(); @@ -130,11 +126,10 @@ boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exceptio sm.checkPermission(new SpecialPermission()); } - return AccessController.doPrivileged(new PrivilegedExceptionAction() { + return AccessController.doPrivileged(new PrivilegedExceptionAction() { @Override - public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, ParserConfigurationException, - SAXException, SettingsException { - return handleLowLevel(restRequest, restChannel); + public BytesRestResponse run() throws SamlConfigException, IOException { + return handleLowLevel(restRequest); } }); } catch (PrivilegedActionException e) { @@ -147,13 +142,11 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc } private AuthTokenProcessorAction.Response handleImpl( - RestRequest restRequest, - RestChannel restChannel, String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings - ) throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, SettingsException { + ) { if (token_log.isDebugEnabled()) { try { token_log.debug( @@ -188,8 +181,7 @@ private AuthTokenProcessorAction.Response handleImpl( } } - private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, IOException, - XPathExpressionException, ParserConfigurationException, SAXException, SettingsException { + private BytesRestResponse handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException { try { if (restRequest.getMediaType() != XContentType.JSON) { @@ -234,31 +226,18 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue()); } - AuthTokenProcessorAction.Response responseBody = this.handleImpl( - restRequest, - restChannel, - samlResponseBase64, - samlRequestId, - acsEndpoint, - saml2Settings - ); + AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings); if (responseBody == null) { - return false; + return null; } String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); - restChannel.sendResponse(authenticateResponse); - - return true; + return new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); - - BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); - restChannel.sendResponse(authenticateResponse); - return true; + return new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed"); } } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index cd6209952f..64d816fabf 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -56,7 +56,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.Destroyable; @@ -171,13 +170,15 @@ public String getType() { } @Override - public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authCredentials) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { try { - RestRequest restRequest = restChannel.request(); - Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path()); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); final String suffix = matcher.matches() ? matcher.group(2) : null; - if (API_AUTHTOKEN_SUFFIX.equals(suffix) && this.authTokenProcessorHandler.handle(restRequest, restChannel)) { - return true; + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { + final BytesRestResponse restResponse = this.authTokenProcessorHandler.handle(request); + if (restResponse != null) { + return restResponse; + } } Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached(); @@ -185,12 +186,10 @@ public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authenticateResponse.addHeader("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)); - restChannel.sendResponse(authenticateResponse); - - return true; + return authenticateResponse; } catch (Exception e) { log.error("Error in reRequestAuthentication()", e); - return false; + return null; } } diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index ad1406426b..35d4e05c4c 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -230,7 +230,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g User authenticatedUser = null; - AuthCredentials authCredenetials = null; + AuthCredentials authCredentials = null; HTTPAuthenticator firstChallengingHttpAuthenticator = null; @@ -272,7 +272,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - authCredenetials = ac; + authCredentials = ac; if (ac == null) { // no credentials found in request @@ -280,12 +280,18 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g continue; } - if (authDomain.isChallenge() && httpAuthenticator.reRequestAuthentication(channel, null)) { - auditLog.logFailedLogin("", false, null, request); - if (isTraceEnabled) { - log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + if (authDomain.isChallenge()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, null); + if (restResponse != null) { + auditLog.logFailedLogin("", false, null, request); + if (isTraceEnabled) { + log.trace("No 'Authorization' header, send 401 and 'WWW-Authenticate Basic'"); + } + notifyIpAuthFailureListeners(request, authCredentials); + channel.sendResponse(restResponse); + return false; } - return false; + } else { // no reRequest possible if (isTraceEnabled) { @@ -296,9 +302,12 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } else { org.apache.logging.log4j.ThreadContext.put("user", ac.getUsername()); if (!ac.isComplete()) { + final BytesRestResponse restResponse = httpAuthenticator.reRequestAuthentication(request, ac); // credentials found in request but we need another client challenge - if (httpAuthenticator.reRequestAuthentication(channel, ac)) { + if (restResponse != null) { // auditLog.logFailedLogin(ac.getUsername()+" ", request); --noauditlog + notifyIpAuthFailureListeners(request, ac); + channel.sendResponse(restResponse); return false; } else { // no reRequest possible @@ -376,7 +385,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("User still not authenticated after checking {} auth domains", restAuthDomains.size()); } - if (authCredenetials == null && anonymousAuthEnabled) { + if (authCredentials == null && anonymousAuthEnabled) { final String tenant = Utils.coalesce(request.header("securitytenant"), request.header("security_tenant")); User anonymousUser = new User(User.ANONYMOUS.getName(), new HashSet(User.ANONYMOUS.getRoles()), null); anonymousUser.setRequestedTenant(tenant); @@ -388,6 +397,7 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g } return true; } + BytesRestResponse challengeResponse = null; if (firstChallengingHttpAuthenticator != null) { @@ -395,31 +405,28 @@ && isBlocked(((InetSocketAddress) request.getHttpChannel().getRemoteAddress()).g log.debug("Rerequest with {}", firstChallengingHttpAuthenticator.getClass()); } - if (firstChallengingHttpAuthenticator.reRequestAuthentication(channel, null)) { + challengeResponse = firstChallengingHttpAuthenticator.reRequestAuthentication(request, null); + if (challengeResponse != null) { if (isDebugEnabled) { log.debug("Rerequest {} failed", firstChallengingHttpAuthenticator.getClass()); } - - log.warn( - "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), - remoteAddress - ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); - return false; } } log.warn( "Authentication finally failed for {} from {}", - authCredenetials == null ? null : authCredenetials.getUsername(), + authCredentials == null ? null : authCredentials.getUsername(), remoteAddress ); - auditLog.logFailedLogin(authCredenetials == null ? null : authCredenetials.getUsername(), false, null, request); + auditLog.logFailedLogin(authCredentials == null ? null : authCredentials.getUsername(), false, null, request); - notifyIpAuthFailureListeners(request, authCredenetials); + notifyIpAuthFailureListeners(request, authCredentials); - channel.sendResponse(new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed")); + channel.sendResponse( + challengeResponse != null + ? challengeResponse + : new BytesRestResponse(RestStatus.UNAUTHORIZED, "Authentication finally failed") + ); return false; } diff --git a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java index fa5065ef68..70b94d6dbf 100644 --- a/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java +++ b/src/main/java/org/opensearch/security/auth/HTTPAuthenticator.java @@ -28,7 +28,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.user.AuthCredentials; @@ -71,15 +71,14 @@ public interface HTTPAuthenticator { /** * If the {@code extractCredentials()} call was not successful or the authentication flow needs another roundtrip this method - * will be called. If the custom HTTP authenticator does not support this method is a no-op and false should be returned. - * + * will be called. If the custom HTTP authenticator does not support this method is a no-op and null response should be returned. * If the custom HTTP authenticator does support re-request authentication or supports authentication flows with multiple roundtrips - * then the response should be sent (through the channel) and true must be returned. + * then the response will be returned which can then be sent via response channel. * - * @param channel The rest channel to sent back the response via {@code channel.sendResponse()} + * @param request * @param credentials The credentials from the prior authentication attempt - * @return false if re-request is not supported/necessary, true otherwise. - * If true is returned {@code channel.sendResponse()} must be called so that the request completes. + * @return null if re-request is not supported/necessary, response object otherwise. + * If an object is returned {@code channel.sendResponse()} must be called so that the request completes. */ - boolean reRequestAuthentication(final RestChannel channel, AuthCredentials credentials); + BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials); } diff --git a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java index 4be83bc2e2..16a822df6d 100644 --- a/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPBasicAuthenticator.java @@ -34,7 +34,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auth.HTTPAuthenticator; @@ -65,11 +64,10 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "Unauthorized"); wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Basic realm=\"OpenSearch Security\""); - channel.sendResponse(wwwAuthenticateResponse); - return true; + return wwwAuthenticateResponse; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java index b1e5d4ef40..cea59f7311 100644 --- a/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPClientCertAuthenticator.java @@ -41,7 +41,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -98,8 +98,8 @@ public AuthCredentials extractCredentials(final RestRequest request, final Threa } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java index a58a842394..1db7b9769b 100644 --- a/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/HTTPProxyAuthenticator.java @@ -37,7 +37,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.support.ConfigConstants; @@ -89,8 +89,8 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java index c47e850b75..02077bed7c 100644 --- a/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/OnBehalfOfAuthenticator.java @@ -33,7 +33,7 @@ import org.opensearch.SpecialPermission; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.authtoken.jwt.EncryptionDecryptionUtil; @@ -241,8 +241,8 @@ public Boolean isRequestAllowed(final RestRequest request) { } @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } @Override diff --git a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java index ef20374d69..0423fecefe 100644 --- a/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java +++ b/src/main/java/org/opensearch/security/http/proxy/HTTPExtendedProxyAuthenticator.java @@ -37,7 +37,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.Strings; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.security.http.HTTPProxyAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -84,11 +83,6 @@ public AuthCredentials extractCredentials(final RestRequest request, ThreadConte return credentials.markComplete(); } - @Override - public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) { - return false; - } - @Override public String getType() { return "extended-proxy"; diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index ff9ec19b09..2594388128 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -47,6 +47,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; @@ -141,7 +142,8 @@ public void basicTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -188,7 +190,8 @@ public void decryptAssertionsTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -236,7 +239,8 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -287,7 +291,8 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -338,7 +343,8 @@ public void shouldNotEscapeSamlEntities() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -389,7 +395,8 @@ public void shouldNotTrimWhitespaceInJwtRoles() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -436,7 +443,8 @@ public void testMetadataBody() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -501,7 +509,8 @@ public void unsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -552,7 +561,8 @@ public void badUnsolicitedSsoTest() throws Exception { ); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status()); } @@ -584,7 +594,8 @@ public void wrongCertTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -613,7 +624,8 @@ public void noSignatureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); Assert.assertEquals(401, tokenRestChannel.response.status().getStatus()); } @@ -646,7 +658,8 @@ public void rolesTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -693,7 +706,8 @@ public void idpEndpointWithQueryStringTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -747,7 +761,8 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -850,7 +865,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); Assert.assertNull(restChannel.response); @@ -870,7 +886,8 @@ public void initialConnectionFailureTest() throws Exception { RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders); TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest); - samlAuthenticator.reRequestAuthentication(tokenRestChannel, null); + authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null); + tokenRestChannel.sendResponse(authenticatorResponse); String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content())); HashMap response = DefaultObjectMapper.objectMapper.readValue( @@ -893,7 +910,8 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); TestRestChannel restChannel = new TestRestChannel(restRequest); - samlAuthenticator.reRequestAuthentication(restChannel, null); + final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null); + restChannel.sendResponse(authenticatorResponse); List wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate"); diff --git a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java index 827bfa24b6..69ddc5c03a 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/AddressBasedRateLimiterTest.java @@ -20,28 +20,27 @@ import org.junit.Test; import org.opensearch.common.settings.Settings; -import org.opensearch.security.user.AuthCredentials; + +import java.net.InetAddress; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class AddressBasedRateLimiterTest { - private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; - @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); + AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertFalse(rateLimiter.isBlocked("a")); - rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); - assertTrue(rateLimiter.isBlocked("a")); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); + assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); } } diff --git a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java index e42d2bd1b8..a8285c42a7 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/UserNameBasedRateLimiterTest.java @@ -17,30 +17,31 @@ package org.opensearch.security.auth.limiting; -import java.net.InetAddress; - import org.junit.Test; import org.opensearch.common.settings.Settings; +import org.opensearch.security.user.AuthCredentials; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class UserNameBasedRateLimiterTest { + private final static byte[] PASSWORD = new byte[] { '1', '2', '3' }; + @Test public void simpleTest() throws Exception { Settings settings = Settings.builder().put("allowed_tries", 3).build(); - AddressBasedRateLimiter rateLimiter = new AddressBasedRateLimiter(settings, null); + UserNameBasedRateLimiter rateLimiter = new UserNameBasedRateLimiter(settings, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertFalse(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); - rateLimiter.onAuthFailure(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }), null, null); - assertTrue(rateLimiter.isBlocked(InetAddress.getByAddress(new byte[] { 1, 2, 3, 4 }))); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertFalse(rateLimiter.isBlocked("a")); + rateLimiter.onAuthFailure(null, new AuthCredentials("a", PASSWORD), null); + assertTrue(rateLimiter.isBlocked("a")); } } diff --git a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java index 55c2e789c6..37ac45080b 100644 --- a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java +++ b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java @@ -16,7 +16,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestChannel; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.security.auth.HTTPAuthenticator; import org.opensearch.security.user.AuthCredentials; @@ -39,8 +39,8 @@ public AuthCredentials extractCredentials(RestRequest request, ThreadContext con } @Override - public boolean reRequestAuthentication(RestChannel channel, AuthCredentials credentials) { - return false; + public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) { + return null; } public static long getCount() {