diff --git a/Spring BFF/BFF/src/main/kotlin/com/frontiers/bff/auth/sessions/RedisSessionCleanUpConfig.kt b/Spring BFF/BFF/src/main/kotlin/com/frontiers/bff/auth/sessions/RedisSessionCleanUpConfig.kt index 19e6996..6b7bc28 100644 --- a/Spring BFF/BFF/src/main/kotlin/com/frontiers/bff/auth/sessions/RedisSessionCleanUpConfig.kt +++ b/Spring BFF/BFF/src/main/kotlin/com/frontiers/bff/auth/sessions/RedisSessionCleanUpConfig.kt @@ -6,9 +6,9 @@ import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration import org.springframework.data.domain.Range import org.springframework.data.redis.connection.Limit -import org.springframework.data.redis.connection.ReactiveRedisConnection -import org.springframework.data.redis.connection.ReturnType import org.springframework.data.redis.core.ReactiveRedisOperations +import org.springframework.data.redis.core.ReactiveRedisTemplate +import org.springframework.data.redis.core.ScanOptions import org.springframework.data.redis.core.script.RedisScript import org.springframework.scheduling.annotation.EnableScheduling import org.springframework.scheduling.annotation.Scheduled @@ -16,8 +16,11 @@ import org.springframework.session.config.ReactiveSessionRepositoryCustomizer import org.springframework.session.data.redis.ReactiveRedisIndexedSessionRepository import org.springframework.session.data.redis.config.ConfigureReactiveRedisAction import org.springframework.stereotype.Component +import reactor.core.publisher.Flux import reactor.core.publisher.Mono import reactor.core.scheduler.Schedulers +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.time.Duration import java.time.Instant import java.util.* @@ -63,11 +66,14 @@ internal class RedisCleanUpConfig { @EnableScheduling internal class SessionEvicter( private val redisOperations: ReactiveRedisOperations, - springSessionProperties: SpringSessionProperties, + private val redisTemplate: ReactiveRedisTemplate, + private val springSessionProperties: SpringSessionProperties, ) { - private val redisKeyLocation = springSessionProperties.redis?.expiredSessionsNamespace + private val redisKeyExpirations = springSessionProperties.redis?.expiredSessionsNamespace ?: "spring:session:sessions:expirations" + private val redisKeyNameSpace = springSessionProperties.redis?.sessionNamespace + ?: "spring:session:sessions" companion object { private const val duration : Long = 120 @@ -91,23 +97,25 @@ internal class SessionEvicter( acquireLock(lockValue) .flatMap { acquired -> if (acquired) { - // Lock acquired, perform the cleanup task + // lock acquired, perform the cleanup task performCleanup() - // release lock 10s before duration time - .then(Mono.delay(Duration.ofSeconds(duration - 10))) - .then(releaseLock(lockValue)) - .onErrorResume { e -> - // Handle errors here - logger.error("Error during cleanup or lock release", e) - Mono.empty() - } + // delete orphaned index keys + .then(cleanupOrphanedIndexedKeys()) + // release lock 10s before duration time + .then(Mono.delay(Duration.ofSeconds(duration - 10))) + .then(releaseLock(lockValue)) + .onErrorResume { e -> + // handle errors here + logger.error("Error during cleanup or lock release", e) + Mono.empty() + } } else { - // Lock not acquired, skip cleanup + // lock not acquired, skip cleanup Mono.empty() } } .onErrorResume { e -> - // Handle errors here + // handle errors here logger.error("Error during lock acquisition or cleanup", e) Mono.empty() } @@ -124,7 +132,7 @@ internal class SessionEvicter( listOf(LOCK_KEY), listOf(lockValue, LOCK_EXPIRY.seconds.toString()) ) - .next() // Converts Flux to Mono + .next() // convert Flux to Mono .map { result -> result == "OK" } } @@ -142,7 +150,7 @@ internal class SessionEvicter( listOf(LOCK_KEY), listOf(lockValue) ) - .next() // Converts Flux to Mono + .next() // convert Flux to Mono .map { result -> result == 1L } } @@ -164,17 +172,17 @@ internal class SessionEvicter( logger.info("Time range start: ${Date(context.pastFiveDays.toEpochMilli())}") logger.info("Time range end: ${Date(context.now.toEpochMilli())}") logger.info("Limit count: ${context.limit.count}") - logger.info("Redis key location: $redisKeyLocation") + logger.info("Redis key location: $redisKeyExpirations") } .flatMap { context -> val zSetOps = redisOperations.opsForZSet() - zSetOps.reverseRangeByScore(redisKeyLocation, context.range, context.limit) + zSetOps.reverseRangeByScore(redisKeyExpirations, context.range, context.limit) .collectList() .flatMap { sessionIdsList -> if (sessionIdsList.isNotEmpty()) { logger.info("Found ${sessionIdsList.size} sessions to remove.") zSetOps.remove( - redisKeyLocation, + redisKeyExpirations, *sessionIdsList.toTypedArray() ).doOnSubscribe { logger.info("Started removal of sessions") } .doOnSuccess { logger.info("Successfully removed sessions") } @@ -198,6 +206,91 @@ internal class SessionEvicter( .subscribeOn(Schedulers.boundedElastic()) // to ensure proper threading } + fun cleanupOrphanedIndexedKeys(): Mono { + // find all indexed keys that match the pattern `namespace:sessions:*:idx` + val pattern = "$redisKeyNameSpace:sessions:*:idx" + val scanOptions = ScanOptions.scanOptions().match(pattern).build() + + return redisTemplate.execute { connection -> + val scanPublisher = connection.keyCommands().scan(scanOptions) + + Flux.from(scanPublisher) + // process each ByteBuffer to extract the indexed key + .flatMap { byteBuffer: ByteBuffer -> + val indexedKey = decodeByteBuffer(byteBuffer) + val sessionId = extractSessionIdFromIndexedKey(indexedKey) + val sessionKey = "${springSessionProperties.redis?.sessionNamespace}:sessions:$sessionId" + + // check if the session key exists + redisTemplate.hasKey(sessionKey) + .flatMap { exists -> + if (!exists) { + redisTemplate.opsForSet().members(indexedKey) + .collectList() + .flatMap { members -> + if (members.isNotEmpty()) { + // create a list of removal operations + val removalOps = members.map { member -> + redisTemplate.opsForSet().remove(member.toString(), sessionId) + .then(Mono.fromRunnable { + logger.info("Removed session ID $sessionId from index set $member") + }) + } + // return the removal operations + Mono.just(removalOps) + } else { + Mono.just(emptyList()) + } + } + } else { + Mono.empty() + } + } + } + .collectList() + .flatMap { allRemovalOps -> + // flatten and execute all removal operations in parallel + val allRemovalOpsFlattened = allRemovalOps.flatten() + if (allRemovalOpsFlattened.isNotEmpty()) { + Flux.merge(allRemovalOpsFlattened) + .then(Mono.fromRunnable { + logger.info("All session IDs removed from all indexed keys.") + }) + } else { + Mono.empty() + } + } + .then(Mono.defer { + Flux.from(scanPublisher) + .collectList() + // flatten and execute all removal operations as a batch + .flatMap { byteBuffers -> + val keysToDelete = byteBuffers.map { decodeByteBuffer(it) } + if (keysToDelete.isNotEmpty()) { + redisTemplate.delete(Flux.fromIterable(keysToDelete)) + .doOnSuccess { + logger.info("Deleted orphaned indexed keys: $keysToDelete") + } + .then() + } else { + Mono.empty() + } + } + }) + }.then() + + } + + // function to extract sessionId from the indexed key + private fun extractSessionIdFromIndexedKey(indexedKey: String): String { + // indexed key format: namespace:sessions::idx + return indexedKey.split(":")[5] // extract sessionId + } + + // utility function to decode ByteBuffer to String + private fun decodeByteBuffer(byteBuffer: ByteBuffer): String { + return StandardCharsets.UTF_8.decode(byteBuffer).toString() + } }