Skip to content

Commit

Permalink
Atomic Reference Approach
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Liang <jiallian@amazon.com>
  • Loading branch information
RyanL1997 committed Jul 10, 2023
1 parent fed2d4e commit e8a62ae
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ private void setUserInfoInThreadContext(User user, Set<String> mappedRoles) {
StringJoiner joiner = new StringJoiner("|");
joiner.add(user.getName());
joiner.add(String.join(",", user.getRoles()));
joiner.add(String.join(",", Sets.union(user.getSecurityRoles().stream().collect(ImmutableSet.toImmutableSet()), mappedRoles)));
joiner.add(String.join(",", Sets.union(user.getSecurityRoles(), mappedRoles)));
String requestedTenant = user.getRequestedTenant();
if (!Strings.isNullOrEmpty(requestedTenant)) {
joiner.add(requestedTenant);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;

import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -145,7 +146,7 @@ private SafeObjectOutputStream(OutputStream out) throws IOException {
@Override
protected Object replaceObject(Object obj) throws IOException {
Class<?> clazz = obj.getClass();
if (isSafeClass(clazz)) {
if (isSafeClass(clazz) || (AtomicReference.class.equals(clazz) && isSafeClass(((AtomicReference)obj).get().getClass()))) {
return obj;
}
throw new IOException("Unauthorized serialization attempt " + clazz.getName());
Expand Down Expand Up @@ -189,7 +190,7 @@ public SafeObjectInputStream(InputStream in) throws IOException {
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {

Class<?> clazz = super.resolveClass(desc);
if (isSafeClass(clazz)) {
if (isSafeClass(clazz) || AtomicReference.class.equals(clazz)) {
return clazz;
}

Expand Down
22 changes: 13 additions & 9 deletions src/main/java/org/opensearch/security/user/User.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import org.opensearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -73,7 +78,7 @@ public class User implements Serializable, Writeable, CustomAttributesAware {
* roles == backend_roles
*/
private final Set<String> roles = Collections.synchronizedSet(new HashSet<String>());
private final Set<String> securityRoles = Collections.synchronizedSet(new HashSet<String>());
private final AtomicReference<Set<String>> securityRoles = new AtomicReference<>(new HashSet<String>());
private String requestedTenant;
private Map<String, String> attributes = Collections.synchronizedMap(new HashMap<>());
private boolean isInjected = false;
Expand All @@ -84,7 +89,7 @@ public User(final StreamInput in) throws IOException {
roles.addAll(in.readList(StreamInput::readString));
requestedTenant = in.readString();
attributes = Collections.synchronizedMap(in.readMap(StreamInput::readString, StreamInput::readString));
securityRoles.addAll(in.readList(StreamInput::readString));
securityRoles.get().addAll(in.readList(StreamInput::readString));
}

/**
Expand Down Expand Up @@ -257,7 +262,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(new ArrayList<String>(roles));
out.writeString(requestedTenant);
out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString);
out.writeStringCollection(securityRoles == null ? Collections.emptyList() : new ArrayList<String>(securityRoles));
out.writeStringCollection(securityRoles.get() ==null?Collections.emptyList():new ArrayList<String>(securityRoles.get()));
}

/**
Expand All @@ -272,15 +277,14 @@ public synchronized final Map<String, String> getCustomAttributesMap() {
return attributes;
}

public synchronized final void addSecurityRoles(final Collection<String> securityRoles) {
public final void addSecurityRoles(final Collection<String> securityRoles) {
if (securityRoles != null && this.securityRoles != null) {
this.securityRoles.addAll(securityRoles);
List<String> filteredRoles = securityRoles.stream().filter(r -> r != null).collect(Collectors.toList());
this.securityRoles.get().addAll(filteredRoles);
}
}

public synchronized final Set<String> getSecurityRoles() {
return this.securityRoles == null
? Collections.synchronizedSet(Collections.emptySet())
: Collections.unmodifiableSet(this.securityRoles);
public final Set<String> getSecurityRoles() {
return this.securityRoles.get() == null ? Collections.synchronizedSet(Collections.emptySet()) : ImmutableSet.copyOf(this.securityRoles.get());
}
}

0 comments on commit e8a62ae

Please sign in to comment.