Skip to content

Commit

Permalink
OnBehalfOf tokens odds and ends (opensearch-project#3593)
Browse files Browse the repository at this point in the history
I was doing some inspection around the OBO feature and noticed some
items to clean up.

Signed-off-by: Peter Nied <petern@amazon.com>
  • Loading branch information
peternied authored and RyanL1997 committed Nov 2, 2023
1 parent e892c01 commit caa78c7
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void shouldLoadDefaultConfiguration() {
Awaitility.await().alias("Load default configuration").until(() -> client.getAuthInfo().getStatusCode(), equalTo(200));
}
try (TestRestClient client = cluster.getRestClient(ADMIN_USER_NAME, DEFAULT_PASSWORD)) {
client.assertCorrectCredentials(ADMIN_USER_NAME);
client.confirmCorrectCredentials(ADMIN_USER_NAME);
HttpResponse response = client.get("/_plugins/_security/api/internalusers");
response.assertStatusCode(200);
Map<String, Object> users = response.getBodyAs(Map.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ public void shouldCreateUserViaRestApi_success() {
assertThat(httpResponse.getStatusCode(), equalTo(201));
}
try (TestRestClient client = cluster.getRestClient(USER_ADMIN)) {
client.assertCorrectCredentials(USER_ADMIN.getName());
client.confirmCorrectCredentials(USER_ADMIN.getName());
}
try (TestRestClient client = cluster.getRestClient(ADDITIONAL_USER_1, ADDITIONAL_PASSWORD_1)) {
client.assertCorrectCredentials(ADDITIONAL_USER_1);
client.confirmCorrectCredentials(ADDITIONAL_USER_1);
}
}

Expand Down Expand Up @@ -160,10 +160,10 @@ public void shouldCreateUserViaRestApiWhenAdminIsAuthenticatedViaCertificate_pos
httpResponse.assertStatusCode(201);
}
try (TestRestClient client = cluster.getRestClient(USER_ADMIN)) {
client.assertCorrectCredentials(USER_ADMIN.getName());
client.confirmCorrectCredentials(USER_ADMIN.getName());
}
try (TestRestClient client = cluster.getRestClient(ADDITIONAL_USER_2, ADDITIONAL_PASSWORD_2)) {
client.assertCorrectCredentials(ADDITIONAL_USER_2);
client.confirmCorrectCredentials(ADDITIONAL_USER_2);
}
}

Expand All @@ -189,10 +189,10 @@ public void shouldStillWorkAfterUpdateOfSecurityConfig() {
cluster.updateUserConfiguration(users);

try (TestRestClient client = cluster.getRestClient(USER_ADMIN)) {
client.assertCorrectCredentials(USER_ADMIN.getName());
client.confirmCorrectCredentials(USER_ADMIN.getName());
}
try (TestRestClient client = cluster.getRestClient(newUser)) {
client.assertCorrectCredentials(newUser.getName());
client.confirmCorrectCredentials(newUser.getName());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public void shouldAuthenticateUserWithCertificate_positiveUserSpoke() {
CertificateData userSpockCertificate = TEST_CERTIFICATES.issueUserCertificate(BACKEND_ROLE_BRIDGE, USER_SPOCK);
try (TestRestClient client = cluster.getRestClient(userSpockCertificate)) {

client.assertCorrectCredentials(USER_SPOCK);
client.confirmCorrectCredentials(USER_SPOCK);
}
}

Expand All @@ -98,7 +98,7 @@ public void shouldAuthenticateUserWithCertificate_positiveUserKirk() {
CertificateData userSpockCertificate = TEST_CERTIFICATES.issueUserCertificate(BACKEND_ROLE_BRIDGE, USER_KIRK);
try (TestRestClient client = cluster.getRestClient(userSpockCertificate)) {

client.assertCorrectCredentials(USER_KIRK);
client.confirmCorrectCredentials(USER_KIRK);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,15 @@
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import javax.crypto.SecretKey;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import org.apache.http.Header;
import org.apache.http.HttpStatus;
import org.apache.http.message.BasicHeader;
import org.junit.Assert;
import io.jsonwebtoken.security.Keys;

import org.junit.ClassRule;
import org.junit.Test;
Expand All @@ -42,11 +39,10 @@
import org.opensearch.test.framework.cluster.TestRestClient;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.junit.Assert.assertTrue;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.contains;
import static org.opensearch.security.support.ConfigConstants.SECURITY_ALLOW_DEFAULT_INIT_SECURITYINDEX;
import static org.opensearch.security.support.ConfigConstants.SECURITY_RESTAPI_ROLES_ENABLED;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
Expand All @@ -60,6 +56,7 @@ public class OnBehalfOfJwtAuthenticationTest {

static final TestSecurityConfig.User ADMIN_USER = new TestSecurityConfig.User("admin").roles(ALL_ACCESS);

private static final String CREATE_OBO_TOKEN_PATH = "_plugins/_security/api/generateonbehalfoftoken";
private static Boolean oboEnabled = true;
private static final String signingKey = Base64.getEncoder()
.encodeToString(
Expand Down Expand Up @@ -143,7 +140,7 @@ public void shouldNotAuthenticateForUsingOBOTokenToAccessOBOEndpoint() {
Header adminOboAuthHeader = new BasicHeader("Authorization", "Bearer " + oboToken);

try (TestRestClient client = cluster.getRestClient(adminOboAuthHeader)) {
TestRestClient.HttpResponse response = client.getOnBehalfOfToken(OBO_DESCRIPTION, adminOboAuthHeader);
TestRestClient.HttpResponse response = client.postJson(CREATE_OBO_TOKEN_PATH, OBO_DESCRIPTION);
response.assertStatusCode(HttpStatus.SC_UNAUTHORIZED);
}
}
Expand All @@ -154,7 +151,7 @@ public void shouldNotAuthenticateForUsingOBOTokenToAccessAccountEndpoint() {
Header adminOboAuthHeader = new BasicHeader("Authorization", "Bearer " + oboToken);

try (TestRestClient client = cluster.getRestClient(adminOboAuthHeader)) {
TestRestClient.HttpResponse response = client.changeInternalUserPassword(CURRENT_AND_NEW_PASSWORDS, adminOboAuthHeader);
TestRestClient.HttpResponse response = client.putJson("_plugins/_security/api/account", CURRENT_AND_NEW_PASSWORDS);
response.assertStatusCode(HttpStatus.SC_UNAUTHORIZED);
}
}
Expand All @@ -177,54 +174,46 @@ public void shouldNotAuthenticateForNonAdminUserWithoutOBOPermission() {
public void shouldNotIncludeRolesFromHostMappingInOBOToken() {
String oboToken = generateOboToken(OBO_USER_NAME_WITH_HOST_MAPPING, DEFAULT_PASSWORD);

SecretKey key = Keys.hmacShaKeyFor(Base64.getDecoder().decode(signingKey));

Claims claims = Jwts.parser().verifyWith(key).build().parseSignedClaims(oboToken).getPayload();
Claims claims = Jwts.parser().setSigningKey(Base64.getDecoder().decode(signingKey)).build().parseClaimsJws(oboToken).getBody();

Object er = claims.get("er");
EncryptionDecryptionUtil encryptionDecryptionUtil = new EncryptionDecryptionUtil(encryptionKey);
String rolesClaim = encryptionDecryptionUtil.decrypt(er.toString());
List<String> roles = Arrays.stream(rolesClaim.split(","))
.map(String::trim)
.filter(s -> !s.isEmpty())
.collect(Collectors.toUnmodifiableList());
Set<String> roles = Arrays.stream(rolesClaim.split(",")).map(String::trim).filter(s -> !s.isEmpty()).collect(Collectors.toSet());

Assert.assertFalse(roles.contains("host_mapping_role"));
assertThat(roles, equalTo(HOST_MAPPING_OBO_USER.getRoleNames()));
assertThat(roles, not(contains("host_mapping_role")));
}

@Test
public void shouldNotAuthenticateWithInvalidDurationSeconds() {
try (TestRestClient client = cluster.getRestClient(ADMIN_USER_NAME, DEFAULT_PASSWORD)) {
client.assertCorrectCredentials(ADMIN_USER_NAME);
client.confirmCorrectCredentials(ADMIN_USER_NAME);
TestRestClient.HttpResponse response = client.postJson(OBO_ENDPOINT_PREFIX, OBO_DESCRIPTION_WITH_INVALID_DURATIONSECONDS);
response.assertStatusCode(HttpStatus.SC_BAD_REQUEST);
Map<String, Object> oboEndPointResponse = (Map<String, Object>) response.getBodyAs(Map.class);
assertTrue(oboEndPointResponse.containsValue("durationSeconds must be an integer."));
assertThat(response.getTextFromJsonBody("/error"), equalTo("durationSeconds must be an integer."));
}
}

@Test
public void shouldNotAuthenticateWithInvalidAPIParameter() {
try (TestRestClient client = cluster.getRestClient(ADMIN_USER_NAME, DEFAULT_PASSWORD)) {
client.assertCorrectCredentials(ADMIN_USER_NAME);
client.confirmCorrectCredentials(ADMIN_USER_NAME);
TestRestClient.HttpResponse response = client.postJson(OBO_ENDPOINT_PREFIX, OBO_DESCRIPTION_WITH_INVALID_PARAMETERS);
response.assertStatusCode(HttpStatus.SC_BAD_REQUEST);
Map<String, Object> oboEndPointResponse = (Map<String, Object>) response.getBodyAs(Map.class);
assertTrue(oboEndPointResponse.containsValue("Unrecognized parameter: invalidParameter"));
assertThat(response.getTextFromJsonBody("/error"), equalTo("Unrecognized parameter: invalidParameter"));
}
}

private String generateOboToken(String username, String password) {
try (TestRestClient client = cluster.getRestClient(username, password)) {
client.assertCorrectCredentials(username);
client.confirmCorrectCredentials(username);
TestRestClient.HttpResponse response = client.postJson(OBO_ENDPOINT_PREFIX, OBO_TOKEN_REASON);
response.assertStatusCode(HttpStatus.SC_OK);
Map<String, Object> oboEndPointResponse = (Map<String, Object>) response.getBodyAs(Map.class);
assertThat(
oboEndPointResponse,
allOf(aMapWithSize(3), hasKey("user"), hasKey("authenticationToken"), hasKey("durationSeconds"))
);
return oboEndPointResponse.get("authenticationToken").toString();
assertThat(response.getTextFromJsonBody("/user"), notNullValue());
assertThat(response.getTextFromJsonBody("/authenticationToken"), notNullValue());
assertThat(response.getTextFromJsonBody("/durationSeconds"), notNullValue());
return response.getTextFromJsonBody("/authenticationToken").toString();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,37 +135,6 @@ public void confirmCorrectCredentials(String expectedUserName) {
assertThat(message, username, equalTo(expectedUserName));
}

public HttpResponse getOnBehalfOfToken(String jsonData, Header... headers) {
try {
HttpPost httpPost = new HttpPost(
new URIBuilder(getHttpServerUri() + "/_plugins/_security/api/generateonbehalfoftoken?pretty").build()
);
httpPost.setEntity(toStringEntity(jsonData));
return executeRequest(httpPost, mergeHeaders(CONTENT_TYPE_JSON, headers));
} catch (URISyntaxException ex) {
throw new RuntimeException("Incorrect URI syntax", ex);
}
}

public HttpResponse changeInternalUserPassword(String jsonData, Header... headers) {
try {
HttpPut httpPut = new HttpPut(new URIBuilder(getHttpServerUri() + "/_plugins/_security/api/account?pretty").build());
httpPut.setEntity(toStringEntity(jsonData));
return executeRequest(httpPut, mergeHeaders(CONTENT_TYPE_JSON, headers));
} catch (URISyntaxException ex) {
throw new RuntimeException("Incorrect URI syntax", ex);
}
}

public void assertCorrectCredentials(String expectedUserName) {
HttpResponse response = getAuthInfo();
assertThat(response, notNullValue());
response.assertStatusCode(200);
String username = response.getTextFromJsonBody("/user_name");
String message = String.format("Expected user name is '%s', but was '%s'", expectedUserName, username);
assertThat(message, username, equalTo(expectedUserName));
}

public HttpResponse head(String path, Header... headers) {
return executeRequest(new HttpHead(getHttpServerUri() + "/" + path), headers);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -58,33 +57,25 @@ public class CreateOnBehalfOfTokenAction extends BaseRestHandler {

private ConfigModel configModel;

private DynamicConfigModel dcm;

public static final Integer OBO_DEFAULT_EXPIRY_SECONDS = 5 * 60;
public static final Integer OBO_MAX_EXPIRY_SECONDS = 10 * 60;

public static final String DEFAULT_SERVICE = "self-issued";

protected final Logger log = LogManager.getLogger(this.getClass());

private static final Set<String> RECOGNIZED_PARAMS = new HashSet<>(
Arrays.asList("durationSeconds", "description", "roleSecurityMode", "service")
);

@Subscribe
public void onConfigModelChanged(ConfigModel configModel) {
public void onConfigModelChanged(final ConfigModel configModel) {
this.configModel = configModel;
}

@Subscribe
public void onDynamicConfigModelChanged(DynamicConfigModel dcm) {
this.dcm = dcm;
public void onDynamicConfigModelChanged(final DynamicConfigModel dcm) {
final Settings settings = dcm.getDynamicOnBehalfOfSettings();

Settings settings = dcm.getDynamicOnBehalfOfSettings();

Boolean enabled = Boolean.parseBoolean(settings.get("enabled"));
String signingKey = settings.get("signing_key");
String encryptionKey = settings.get("encryption_key");
final Boolean enabled = Boolean.parseBoolean(settings.get("enabled"));
final String signingKey = settings.get("signing_key");
final String encryptionKey = settings.get("encryption_key");

if (!Boolean.FALSE.equals(enabled) && signingKey != null && encryptionKey != null) {
this.vendor = new JwtVendor(settings, Optional.empty());
Expand All @@ -109,7 +100,7 @@ public List<Route> routes() {
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
protected RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
switch (request.method()) {
case POST:
return handlePost(request, client);
Expand All @@ -118,10 +109,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
}

private RestChannelConsumer handlePost(RestRequest request, NodeClient client) throws IOException {
private RestChannelConsumer handlePost(final RestRequest request, final NodeClient client) throws IOException {
return new RestChannelConsumer() {
@Override
public void accept(RestChannel channel) throws Exception {
public void accept(final RestChannel channel) throws Exception {
final XContentBuilder builder = channel.newBuilder();
BytesRestResponse response;
try {
Expand All @@ -141,18 +132,14 @@ public void accept(RestChannel channel) throws Exception {

validateRequestParameters(requestBody);

Integer tokenDuration = parseAndValidateDurationSeconds(requestBody.get("durationSeconds"));
Integer tokenDuration = parseAndValidateDurationSeconds(requestBody.get(InputParameters.DURATION.paramName));
tokenDuration = Math.min(tokenDuration, OBO_MAX_EXPIRY_SECONDS);

final String description = (String) requestBody.getOrDefault("description", null);

final Boolean roleSecurityMode = Optional.ofNullable(requestBody.get("roleSecurityMode"))
.map(value -> (Boolean) value)
.orElse(true); // Default to false if null
final String description = (String) requestBody.getOrDefault(InputParameters.DESCRIPTION.paramName, null);

final String service = (String) requestBody.getOrDefault("service", DEFAULT_SERVICE);
final String service = (String) requestBody.getOrDefault(InputParameters.SERVICE.paramName, DEFAULT_SERVICE);
final User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
Set<String> mappedRoles = mapRoles(user);
final Set<String> mappedRoles = mapRoles(user);

builder.startObject();
builder.field("user", user.getName());
Expand All @@ -164,14 +151,14 @@ public void accept(RestChannel channel) throws Exception {
tokenDuration,
mappedRoles.stream().collect(Collectors.toList()),
user.getRoles().stream().collect(Collectors.toList()),
roleSecurityMode
false
);
builder.field("authenticationToken", token);
builder.field("durationSeconds", tokenDuration);
builder.endObject();

response = new BytesRestResponse(RestStatus.OK, builder);
} catch (IllegalArgumentException iae) {
} catch (final IllegalArgumentException iae) {
builder.startObject().field("error", iae.getMessage()).endObject();
response = new BytesRestResponse(RestStatus.BAD_REQUEST, builder);
} catch (final Exception exception) {
Expand All @@ -187,19 +174,32 @@ public void accept(RestChannel channel) throws Exception {
};
}

private enum InputParameters {
DURATION("durationSeconds"),
DESCRIPTION("description"),
SERVICE("service");

final String paramName;

private InputParameters(final String paramName) {
this.paramName = paramName;
}
}

private Set<String> mapRoles(final User user) {
return this.configModel.mapSecurityRoles(user, null);
}

private void validateRequestParameters(Map<String, Object> requestBody) throws IllegalArgumentException {
for (String key : requestBody.keySet()) {
if (!RECOGNIZED_PARAMS.contains(key)) {
throw new IllegalArgumentException("Unrecognized parameter: " + key);
}
private void validateRequestParameters(final Map<String, Object> requestBody) throws IllegalArgumentException {
for (final String key : requestBody.keySet()) {
Arrays.stream(InputParameters.values())
.filter(param -> param.paramName.equalsIgnoreCase(key))
.findAny()
.orElseThrow(() -> new IllegalArgumentException("Unrecognized parameter: " + key));
}
}

private Integer parseAndValidateDurationSeconds(Object durationObj) throws IllegalArgumentException {
private Integer parseAndValidateDurationSeconds(final Object durationObj) throws IllegalArgumentException {
if (durationObj == null) {
return OBO_DEFAULT_EXPIRY_SECONDS;
}
Expand All @@ -209,7 +209,7 @@ private Integer parseAndValidateDurationSeconds(Object durationObj) throws Illeg
} else if (durationObj instanceof String) {
try {
return Integer.parseInt((String) durationObj);
} catch (NumberFormatException ignored) {}
} catch (final NumberFormatException ignored) {}
}
throw new IllegalArgumentException("durationSeconds must be an integer.");
}
Expand Down
Loading

0 comments on commit caa78c7

Please sign in to comment.