Skip to content

Commit

Permalink
adjust handling of null dictionary-values in seeder-settings
Browse files Browse the repository at this point in the history
  • Loading branch information
ntruchsess committed Jun 26, 2024
1 parent 6ae3106 commit 329439f
Show file tree
Hide file tree
Showing 18 changed files with 240 additions and 174 deletions.
13 changes: 13 additions & 0 deletions src/framework/Framework.Linq/IfAnyExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,17 @@ public static bool IfAny<T, R>(this IEnumerable<T> source, Func<IEnumerable<T>,
returnValue = null;
return false;
}

public static async ValueTask<bool> IfAnyAwait<T>(this IEnumerable<T> source, Func<IEnumerable<T>, Task> process)
{
var enumerator = source.GetEnumerator();

if (enumerator.MoveNext())
{
await process(new IfAnyEnumerable<T>(source, enumerator)).ConfigureAwait(ConfigureAwaitOptions.None);
return true;
}

return false;
}
}
22 changes: 18 additions & 4 deletions src/framework/Framework.Linq/NullOrSequenceEqualExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ namespace Org.Eclipse.TractusX.Portal.Backend.Framework.Linq;

public static class NullOrSequenceEqualExtensions
{
public static bool NullOrContentEqual<T>(this IEnumerable<T>? items, IEnumerable<T>? others, IEqualityComparer<T>? comparer = null) where T : IComparable =>
public static bool NullOrContentEqual<T>(this IEnumerable<T>? items, IEnumerable<T>? others, IEqualityComparer<T>? comparer = null) where T : IComparable? =>
items == null && others == null ||
items != null && others != null &&
items.OrderBy(x => x).SequenceEqual(others.OrderBy(x => x), comparer);
items.Order().SequenceEqual(others.Order(), comparer);

public static bool NullOrContentEqual<K, V>(this IEnumerable<KeyValuePair<K, V>>? items, IEnumerable<KeyValuePair<K, V>>? others, IEqualityComparer<KeyValuePair<K, V>>? comparer = null) where K : IComparable where V : IComparable =>
public static bool NullOrContentEqual<K, V>(this IEnumerable<KeyValuePair<K, V>>? items, IEnumerable<KeyValuePair<K, V>>? others, IEqualityComparer<KeyValuePair<K, V>>? comparer = null) where K : IComparable where V : IComparable? =>
items == null && others == null ||
items != null && others != null &&
items.OrderBy(x => x.Key).SequenceEqual(others.OrderBy(x => x.Key), comparer);
Expand All @@ -37,6 +37,11 @@ public static bool NullOrContentEqual<K, V>(this IEnumerable<KeyValuePair<K, IEn
items == null && others == null ||
items != null && others != null &&
items.OrderBy(x => x.Key).SequenceEqual(others.OrderBy(x => x.Key), comparer ?? new EnumerableValueKeyValuePairEqualityComparer<K, V>());

public static bool NullOrNullableContentEqual<K, V>(this IEnumerable<KeyValuePair<K, IEnumerable<V>?>>? items, IEnumerable<KeyValuePair<K, IEnumerable<V>?>>? others, IEqualityComparer<KeyValuePair<K, IEnumerable<V>?>>? comparer = null) where V : IComparable =>
items == null && others == null ||
items != null && others != null &&
items.OrderBy(x => x.Key).SequenceEqual(others.OrderBy(x => x.Key), comparer ?? new NullableEnumerableValueKeyValuePairEqualityComparer<K, V>());
}

public class KeyValuePairEqualityComparer<K, V> : IEqualityComparer<KeyValuePair<K, V>> where K : IComparable where V : IComparable
Expand All @@ -50,7 +55,16 @@ public class EnumerableValueKeyValuePairEqualityComparer<K, V> : IEqualityCompar
{
public bool Equals(KeyValuePair<K, IEnumerable<V>> source, KeyValuePair<K, IEnumerable<V>> other) =>
Equals(source.Key, other.Key) &&
source.Value.NullOrContentEqual(other.Value);
source.Value.Order().SequenceEqual(other.Value.Order());

public int GetHashCode([DisallowNull] KeyValuePair<K, IEnumerable<V>> obj) => throw new NotImplementedException();
}

public class NullableEnumerableValueKeyValuePairEqualityComparer<K, V> : IEqualityComparer<KeyValuePair<K, IEnumerable<V>?>> where V : IComparable
{
public bool Equals(KeyValuePair<K, IEnumerable<V>?> source, KeyValuePair<K, IEnumerable<V>?> other) =>
Equals(source.Key, other.Key) &&
source.Value.NullOrContentEqual(other.Value);

public int GetHashCode([DisallowNull] KeyValuePair<K, IEnumerable<V>?> obj) => throw new NotImplementedException();
}
20 changes: 20 additions & 0 deletions src/framework/Framework.Linq/NullableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,24 @@ public static class NullableExtensions
{
public static bool IsNullOrEmpty<T>(this IEnumerable<T>? collection) =>
collection == null || !collection.Any();

public static IEnumerable<KeyValuePair<TKey, TValue>> FilterNotNullValues<TKey, TValue>(this IEnumerable<KeyValuePair<TKey, TValue?>> source)
{
foreach (var (key, value) in source)
{
if (value == null)
continue;
yield return KeyValuePair.Create(key, value);
}
}

public static IEnumerable<T> FilterNotNull<T>(this IEnumerable<T?> source)
{
foreach (var item in source)
{
if (item == null)
continue;
yield return item;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public AuthenticationFlowHandler(KeycloakClient keycloak, ISeedDataHandler seedD

public async Task UpdateAuthenticationFlows(CancellationToken cancellationToken)
{
var flows = (await _keycloak.GetAuthenticationFlowsAsync(_realm, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None));
var flows = await _keycloak.GetAuthenticationFlowsAsync(_realm, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None);
var seedFlows = _seedData.TopLevelCustomAuthenticationFlows;
var topLevelCustomFlows = flows.Where(flow => !(flow.BuiltIn ?? false) && (flow.TopLevel ?? false));

Expand Down Expand Up @@ -275,7 +275,7 @@ await _keycloak.CreateAuthenticationExecutionConfigurationAsync(
new AuthenticatorConfig
{
Alias = update.AuthenticatorConfig,
Config = _seedData.GetAuthenticatorConfig(update.AuthenticatorConfig).Config?.ToDictionary(x => x.Key, x => x.Value)
Config = _seedData.GetAuthenticatorConfig(update.AuthenticatorConfig).Config?.FilterNotNullValues()?.ToDictionary()
},
cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None);
break;
Expand All @@ -292,7 +292,7 @@ await _keycloak.DeleteAuthenticatorConfigurationAsync(
if (config == null)
throw new UnexpectedConditionException("authenticatorConfig is null");
config.Alias = update.AuthenticatorConfig;
config.Config = updateConfig.Config?.ToDictionary(x => x.Key, x => x.Value);
config.Config = updateConfig.Config?.FilterNotNullValues()?.ToDictionary();
await _keycloak.UpdateAuthenticatorConfigurationAsync(_realm, execution.AuthenticationConfig, config, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None);
break;
}
Expand Down Expand Up @@ -324,7 +324,7 @@ private bool CompareFlowExecutions(AuthenticationFlowExecution execution, Authen

private static bool CompareAuthenticatorConfig(AuthenticatorConfig config, AuthenticatorConfigModel update) =>
config.Alias == update.Alias &&
config.Config.NullOrContentEqual(update.Config);
config.Config.NullOrContentEqual(update.Config?.FilterNotNullValues());

private Task<IEnumerable<AuthenticationFlowExecution>> GetExecutions(string alias, CancellationToken cancellationToken) =>
_keycloak.GetAuthenticationFlowExecutionsAsync(_realm, alias, cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
********************************************************************************/

using Org.Eclipse.TractusX.Portal.Backend.Framework.ErrorHandling;
using Org.Eclipse.TractusX.Portal.Backend.Framework.Linq;
using Org.Eclipse.TractusX.Portal.Backend.Keycloak.Factory;
using Org.Eclipse.TractusX.Portal.Backend.Keycloak.Library;
using Org.Eclipse.TractusX.Portal.Backend.Keycloak.Library.Models.Roles;
Expand Down Expand Up @@ -58,24 +59,18 @@ public async Task UpdateClientScopeMapper(string instanceName, CancellationToken
throw new ConflictException($"No client id found with name {clientName}");
}
var clientRoles = await keycloak.GetClientRolesScopeMappingsForClientAsync(realm, clientScope.Id, client.Id, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None);
var mappingModelRoles = mappingModel.Roles.Select(roleName => roles.SingleOrDefault(r => r.Name == roleName) ?? throw new ConflictException($"No role with name {roleName} found"));
var mappingModelRoles = mappingModel.Roles?.Select(roleName => roles.SingleOrDefault(r => r.Name == roleName) ?? throw new ConflictException($"No role with name {roleName} found")) ?? Enumerable.Empty<Role>();
await AddAndDeleteRoles(keycloak, realm, clientScope.Id, client.Id, clientRoles, mappingModelRoles, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None);
}
}
}

private static async Task AddAndDeleteRoles(KeycloakClient keycloak, string realm, string clientScopeId, string clientId, IEnumerable<Role> roles, IEnumerable<Role> updateRoles, CancellationToken cancellationToken)
{
var rolesToAdd = updateRoles.ExceptBy(roles.Select(role => role.Name), roleModel => roleModel.Name).ToList();
var rolesToDelete = roles.ExceptBy(updateRoles.Select(roleModel => roleModel.Name), role => role.Name).ToList();
if (rolesToDelete.Any())
{
await keycloak.RemoveClientRolesFromClientScopeForClientAsync(realm, clientScopeId, clientId, rolesToDelete, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None);
}
await updateRoles.ExceptBy(roles.Select(role => role.Name), roleModel => roleModel.Name).IfAnyAwait(rolesToAdd =>
keycloak.AddClientRolesScopeMappingToClientAsync(realm, clientScopeId, clientId, rolesToAdd, cancellationToken)).ConfigureAwait(false);

if (rolesToAdd.Any())
{
await keycloak.AddClientRolesScopeMappingToClientAsync(realm, clientScopeId, clientId, rolesToAdd, cancellationToken).ConfigureAwait(ConfigureAwaitOptions.None);
}
await roles.ExceptBy(updateRoles.Select(roleModel => roleModel.Name), role => role.Name).IfAnyAwait(rolesToDelete =>
keycloak.RemoveClientRolesFromClientScopeForClientAsync(realm, clientScopeId, clientId, rolesToDelete, cancellationToken)).ConfigureAwait(false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,15 @@ private static bool CompareClientScope(ClientScope scope, ClientScopeModel updat
scope.Attributes != null && update.Attributes != null &&
CompareClientScopeAttributes(scope.Attributes, update.Attributes));

private static Attributes CreateClientScopeAttributes(IReadOnlyDictionary<string, string> update) =>
new Attributes
private static Attributes CreateClientScopeAttributes(IReadOnlyDictionary<string, string?> update) =>
new()
{
ConsentScreenText = update.GetValueOrDefault("consent.screen.text"),
DisplayOnConsentScreen = update.GetValueOrDefault("display.on.consent.screen"),
IncludeInTokenScope = update.GetValueOrDefault("include.in.token.scope")
};

private static bool CompareClientScopeAttributes(Attributes attributes, IReadOnlyDictionary<string, string> update) =>
private static bool CompareClientScopeAttributes(Attributes attributes, IReadOnlyDictionary<string, string?> update) =>
attributes.ConsentScreenText == update.GetValueOrDefault("consent.screen.text") &&
attributes.DisplayOnConsentScreen == update.GetValueOrDefault("display.on.consent.screen") &&
attributes.IncludeInTokenScope == update.GetValueOrDefault("include.in.token.scope");
Expand Down
10 changes: 5 additions & 5 deletions src/keycloak/Keycloak.Seeding/BusinessLogic/ClientsUpdater.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ await keycloak.UpdateClientProtocolMapperAsync(
PublicClient = update.PublicClient,
FrontChannelLogout = update.FrontchannelLogout,
Protocol = update.Protocol,
Attributes = update.Attributes?.ToDictionary(x => x.Key, x => x.Value),
AuthenticationFlowBindingOverrides = update.AuthenticationFlowBindingOverrides?.ToDictionary(x => x.Key, x => x.Value),
Attributes = update.Attributes?.FilterNotNullValues()?.ToDictionary(),
AuthenticationFlowBindingOverrides = update.AuthenticationFlowBindingOverrides?.FilterNotNullValues()?.ToDictionary(),
FullScopeAllowed = update.FullScopeAllowed,
NodeReregistrationTimeout = update.NodeReRegistrationTimeout,
DefaultClientScopes = update.DefaultClientScopes,
Expand Down Expand Up @@ -202,8 +202,8 @@ private static bool CompareClient(Client client, ClientModel update) =>
client.PublicClient == update.PublicClient &&
client.FrontChannelLogout == update.FrontchannelLogout &&
client.Protocol == update.Protocol &&
client.Attributes.NullOrContentEqual(update.Attributes) &&
client.AuthenticationFlowBindingOverrides.NullOrContentEqual(update.AuthenticationFlowBindingOverrides) &&
client.Attributes.NullOrContentEqual(update.Attributes?.FilterNotNullValues()) &&
client.AuthenticationFlowBindingOverrides.NullOrContentEqual(update.AuthenticationFlowBindingOverrides?.FilterNotNullValues()) &&
client.FullScopeAllowed == update.FullScopeAllowed &&
client.NodeReregistrationTimeout == update.NodeReRegistrationTimeout &&
client.DefaultClientScopes.NullOrContentEqual(update.DefaultClientScopes) &&
Expand All @@ -228,7 +228,7 @@ private static bool CompareClientProtocolMapper(ClientProtocolMapper mapper, Pro
mapper.Config != null && update.Config != null &&
CompareClientProtocolMapperConfig(mapper.Config, update.Config));

private static bool CompareClientProtocolMapperConfig(ClientConfig config, IReadOnlyDictionary<string, string> update) =>
private static bool CompareClientProtocolMapperConfig(ClientConfig config, IReadOnlyDictionary<string, string?> update) =>
config.UserInfoTokenClaim == update.GetValueOrDefault("userinfo.token.claim") &&
config.UserAttribute == update.GetValueOrDefault("user.attribute") &&
config.IdTokenClaim == update.GetValueOrDefault("id.token.claim") &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface ISeedDataHandler

IEnumerable<ClientModel> Clients { get; }

IReadOnlyDictionary<string, IEnumerable<RoleModel>> ClientRoles { get; }
IEnumerable<(string ClientId, IEnumerable<RoleModel> RoleModels)> ClientRoles { get; }

IEnumerable<RoleModel> RealmRoles { get; }

Expand All @@ -48,7 +48,7 @@ public interface ISeedDataHandler

IReadOnlyDictionary<string, string> ClientsDictionary { get; }

IReadOnlyDictionary<string, IEnumerable<ClientScopeMappingModel>> ClientScopeMappings { get; }
IEnumerable<(string ClientId, IEnumerable<ClientScopeMappingModel> ClientScopeMappingModels)> ClientScopeMappings { get; }

Task SetClientInternalIds(IAsyncEnumerable<(string ClientId, string Id)> clientInternalIds);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/********************************************************************************
* Copyright (c) 2023 BMW Group AG
* Copyright (c) 2023 Contributors to the Eclipse Foundation
*
* See the NOTICE file(s) distributed with this work for additional
Expand Down Expand Up @@ -245,12 +244,12 @@ private static bool CompareIdentityProviderConfig(Config? config, IdentityProvid
private static IdentityProviderMapper UpdateIdentityProviderMapper(IdentityProviderMapper mapper, IdentityProviderMapperModel updateMapper)
{
mapper._IdentityProviderMapper = updateMapper.IdentityProviderMapper;
mapper.Config = updateMapper.Config?.ToDictionary(x => x.Key, x => x.Value);
mapper.Config = updateMapper.Config?.FilterNotNullValues()?.ToDictionary();
return mapper;
}

private static bool CompareIdentityProviderMapper(IdentityProviderMapper mapper, IdentityProviderMapperModel updateMapper) =>
mapper.IdentityProviderAlias == updateMapper.IdentityProviderAlias &&
mapper._IdentityProviderMapper == updateMapper.IdentityProviderMapper &&
mapper.Config.NullOrContentEqual(updateMapper.Config);
mapper.Config.NullOrContentEqual(updateMapper.Config?.FilterNotNullValues());
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/********************************************************************************
* Copyright (c) 2023 BMW Group AG
* Copyright (c) 2023 Contributors to the Eclipse Foundation
*
* See the NOTICE file(s) distributed with this work for additional
Expand Down Expand Up @@ -44,7 +43,7 @@ public static bool CompareProtocolMapper(ProtocolMapper mapper, ProtocolMapperMo
mapper.Config != null && update.Config != null &&
CompareProtocolMapperConfig(mapper.Config, update.Config));

private static Config CreateProtocolMapperConfig(IReadOnlyDictionary<string, string> update) =>
private static Config CreateProtocolMapperConfig(IReadOnlyDictionary<string, string?> update) =>
new Config
{
Single = update.GetValueOrDefault("single"),
Expand All @@ -67,7 +66,7 @@ private static Config CreateProtocolMapperConfig(IReadOnlyDictionary<string, str
UserSessionNote = update.GetValueOrDefault("user.session.note"),
};

private static bool CompareProtocolMapperConfig(Config config, IReadOnlyDictionary<string, string> update) =>
private static bool CompareProtocolMapperConfig(Config config, IReadOnlyDictionary<string, string?> update) =>
config.Single == update.GetValueOrDefault("single") &&
config.AttributeNameFormat == update.GetValueOrDefault("attribute.nameformat") &&
config.AttributeName == update.GetValueOrDefault("attribute.name") &&
Expand Down
9 changes: 3 additions & 6 deletions src/keycloak/Keycloak.Seeding/BusinessLogic/RealmUpdater.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/********************************************************************************
* Copyright (c) 2023 BMW Group AG
* Copyright (c) 2023 Contributors to the Eclipse Foundation
*
* See the NOTICE file(s) distributed with this work for additional
Expand Down Expand Up @@ -126,7 +125,7 @@ public async Task UpdateRealm(string keycloakInstanceName, CancellationToken can
keycloakRealm.ResetCredentialsFlow = seedRealm.ResetCredentialsFlow;
keycloakRealm.ClientAuthenticationFlow = seedRealm.ClientAuthenticationFlow;
keycloakRealm.DockerAuthenticationFlow = seedRealm.DockerAuthenticationFlow;
keycloakRealm.Attributes = seedRealm.Attributes?.ToDictionary(x => x.Key, x => x.Value);
keycloakRealm.Attributes = seedRealm.Attributes?.FilterNotNullValues()?.ToDictionary();
keycloakRealm.UserManagedAccessAllowed = seedRealm.UserManagedAccessAllowed;
keycloakRealm.PasswordPolicy = seedRealm.PasswordPolicy;

Expand Down Expand Up @@ -205,10 +204,8 @@ private static bool CompareRealm(Realm keycloakRealm, KeycloakRealm seedRealm) =
keycloakRealm.UserManagedAccessAllowed == seedRealm.UserManagedAccessAllowed &&
keycloakRealm.PasswordPolicy == seedRealm.PasswordPolicy;

private static bool CompareRealmAttributes(IDictionary<string, string>? attributes, IReadOnlyDictionary<string, string>? updateAttributes) =>
attributes == null && updateAttributes == null ||
attributes != null && updateAttributes != null &&
attributes.OrderBy(x => x.Key).SequenceEqual(updateAttributes.OrderBy(x => x.Key));
private static bool CompareRealmAttributes(IEnumerable<KeyValuePair<string, string>>? attributes, IEnumerable<KeyValuePair<string, string?>>? updateAttributes) =>
attributes.NullOrContentEqual(updateAttributes?.FilterNotNullValues());

private static bool CompareBrowserSecurityHeaders(BrowserSecurityHeaders? securityHeaders, BrowserSecurityHeadersModel? updateSecurityHeaders) =>
securityHeaders == null && updateSecurityHeaders == null ||
Expand Down
Loading

0 comments on commit 329439f

Please sign in to comment.