Skip to content

Commit

Permalink
CID-2776: refresh tokens if expired
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamedlajmileanix committed Aug 5, 2024
1 parent c5715dd commit 158679c
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package net.leanix.githubagent.services

import io.jsonwebtoken.Jwts
import io.jsonwebtoken.SignatureAlgorithm
import net.leanix.githubagent.client.GitHubClient
import net.leanix.githubagent.config.GitHubEnterpriseProperties
import net.leanix.githubagent.dto.Installation
import net.leanix.githubagent.exceptions.FailedToCreateJWTException
import net.leanix.githubagent.exceptions.JwtTokenNotFound
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.slf4j.LoggerFactory
import org.springframework.core.io.ResourceLoader
Expand All @@ -24,7 +27,8 @@ class GitHubAuthenticationService(
private val cachingService: CachingService,
private val githubEnterpriseProperties: GitHubEnterpriseProperties,
private val resourceLoader: ResourceLoader,
private val gitHubEnterpriseService: GitHubEnterpriseService
private val gitHubEnterpriseService: GitHubEnterpriseService,
private val gitHubClient: GitHubClient,
) {

companion object {
Expand All @@ -34,6 +38,15 @@ class GitHubAuthenticationService(
private val logger = LoggerFactory.getLogger(GitHubAuthenticationService::class.java)
}

fun refreshTokens() {
generateJwtToken()
val jwtToken = cachingService.get("jwtToken") ?: throw JwtTokenNotFound()
generateAndCacheInstallationTokens(
gitHubClient.getInstallations("Bearer $jwtToken"),
jwtToken.toString()
)
}

fun generateJwtToken() {
runCatching {
logger.info("Generating JWT token")
Expand Down Expand Up @@ -67,6 +80,16 @@ class GitHubAuthenticationService(
}
}

fun generateAndCacheInstallationTokens(
installations: List<Installation>,
jwtToken: String
) {
installations.forEach { installation ->
val installationToken = gitHubClient.createInstallationToken(installation.id, "Bearer $jwtToken").token
cachingService.set("installationToken:${installation.id}", installationToken, 3600L)
}
}

@Throws(IOException::class)
private fun readPrivateKey(): String {
val pemFile = File(resourceLoader.getResource("file:${githubEnterpriseProperties.pemFile}").uri)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class GitHubScanningService(
private val gitHubClient: GitHubClient,
private val cachingService: CachingService,
private val webSocketService: WebSocketService,
private val gitHubGraphQLService: GitHubGraphQLService
private val gitHubGraphQLService: GitHubGraphQLService,
private val gitHubAuthenticationService: GitHubAuthenticationService
) {

private val logger = LoggerFactory.getLogger(GitHubScanningService::class.java)
Expand All @@ -37,20 +38,10 @@ class GitHubScanningService(

private fun getInstallations(jwtToken: String): List<Installation> {
val installations = gitHubClient.getInstallations("Bearer $jwtToken")
generateAndCacheInstallationTokens(installations, jwtToken)
gitHubAuthenticationService.generateAndCacheInstallationTokens(installations, jwtToken)
return installations
}

private fun generateAndCacheInstallationTokens(
installations: List<Installation>,
jwtToken: String
) {
installations.forEach { installation ->
val installationToken = gitHubClient.createInstallationToken(installation.id, "Bearer $jwtToken").token
cachingService.set("installationToken:${installation.id}", installationToken, 3600L)
}
}

private fun fetchAndSendOrganisationsData(
installations: List<Installation>
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class WebhookService(
private val webSocketService: WebSocketService,
private val gitHubGraphQLService: GitHubGraphQLService,
private val gitHubEnterpriseProperties: GitHubEnterpriseProperties,
private val cachingService: CachingService
private val cachingService: CachingService,
private val gitHubAuthenticationService: GitHubAuthenticationService
) {

private val logger = LoggerFactory.getLogger(WebhookService::class.java)
Expand All @@ -38,9 +39,12 @@ class WebhookService(
val headCommit = pushEventPayload.headCommit
val organizationName = pushEventPayload.repository.owner.name

val installationToken = cachingService.get("installationToken:${pushEventPayload.installation.id}")?.toString()
?: throw IllegalArgumentException("Installation token not found/ expired")
// TODO refresh token if expired
var installationToken = cachingService.get("installationToken:${pushEventPayload.installation.id}")?.toString()
if (installationToken == null) {
gitHubAuthenticationService.refreshTokens()
installationToken = cachingService.get("installationToken:${pushEventPayload.installation.id}")?.toString()
require(installationToken != null) { "Installation token not found/ expired" }
}

if (pushEventPayload.ref == "refs/heads/${pushEventPayload.repository.defaultBranch}") {
when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package net.leanix.githubagent.services

import io.mockk.every
import io.mockk.mockk
import net.leanix.githubagent.client.GitHubClient
import net.leanix.githubagent.config.GitHubEnterpriseProperties
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Assertions.assertThrows
Expand All @@ -16,11 +17,13 @@ class GitHubAuthenticationServiceTest {
private val githubEnterpriseProperties = mockk<GitHubEnterpriseProperties>()
private val resourceLoader = mockk<ResourceLoader>()
private val gitHubEnterpriseService = mockk<GitHubEnterpriseService>()
private val gitHubClient = mockk<GitHubClient>()
private val githubAuthenticationService = GitHubAuthenticationService(
cachingService,
githubEnterpriseProperties,
resourceLoader,
gitHubEnterpriseService
gitHubEnterpriseService,
gitHubClient
)

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ class GitHubScanningServiceTest {
private val cachingService = mockk<CachingService>()
private val webSocketService = mockk<WebSocketService>(relaxUnitFun = true)
private val gitHubGraphQLService = mockk<GitHubGraphQLService>()
private val gitHubAuthenticationService = mockk<GitHubAuthenticationService>()
private val gitHubScanningService = GitHubScanningService(
gitHubClient,
cachingService,
webSocketService,
gitHubGraphQLService
gitHubGraphQLService,
gitHubAuthenticationService
)
private val runId = UUID.randomUUID()

Expand All @@ -47,6 +49,7 @@ class GitHubScanningServiceTest {
cursor = null
)
every { cachingService.remove(any()) } returns Unit
every { gitHubAuthenticationService.generateAndCacheInstallationTokens(any(), any()) } returns Unit
}

@Test
Expand Down

0 comments on commit 158679c

Please sign in to comment.