Skip to content

Commit

Permalink
Merge pull request #2865 from wordpress-mobile/woo/improve-jwt-caching
Browse files Browse the repository at this point in the history
Jetpack AI: improve caching logic of the JWT token
  • Loading branch information
JorgeMucientes authored Oct 10, 2023
2 parents 56bfb5d + 8fd387d commit ba92dd0
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 52 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package org.wordpress.android.fluxc

import android.util.Base64
import org.junit.Assert
import org.junit.Test
import org.wordpress.android.fluxc.model.JWTToken

class JWTTokenTests {
@Test
fun given_a_valid_token__when_validateExpiryDate_is_called__then_return_it() {
val token = generateToken(expired = false)

val result = token.validateExpiryDate()

Assert.assertNotNull(result)
}

@Test
fun given_an_expired_token__when_validateExpiryDate_is_called__then_return_null() {
val token = generateToken(expired = true)

val result = token.validateExpiryDate()

Assert.assertNull(result)
}

private fun generateToken(expired: Boolean): JWTToken {
val expirationTime = System.currentTimeMillis() / 1000 + if (expired) -100 else 100

// Sample token from https://jwt.io/ modifier with an expiration time
val header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
val payload = Base64.encode(
"""
{
"sub": "1234567890",
"name": "John Doe",
"iat": 1516239022,
"exp": $expirationTime,
"expires": $expirationTime
}
""".trimIndent().toByteArray(), Base64.DEFAULT
).decodeToString()
val signature = "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"

return JWTToken("$header.$payload.$signature")
}
}
35 changes: 35 additions & 0 deletions fluxc/src/main/java/org/wordpress/android/fluxc/model/JWTToken.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.wordpress.android.fluxc.model

import android.util.Base64
import org.json.JSONObject

@JvmInline
value class JWTToken(
val value: String
) {
/**
* Returns the token if it is still valid, or null if it is expired.
*/
@Suppress("MagicNumber")
fun validateExpiryDate(): JWTToken? {
fun JSONObject.getLongOrNull(name: String) = this.optLong(name, Long.MAX_VALUE).takeIf { it != Long.MAX_VALUE }

val payloadJson = getPayloadJson()
val expiration = payloadJson.getLongOrNull("exp")
?: payloadJson.getLongOrNull("expires")
?: return null

val now = System.currentTimeMillis() / 1000

return if (expiration > now) this else null
}

fun getPayloadItem(key: String): String? {
return getPayloadJson().optString(key)
}

private fun getPayloadJson(): JSONObject {
val payloadEncoded = this.value.split(".")[1]
return JSONObject(String(Base64.decode(payloadEncoded, Base64.DEFAULT)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.android.volley.RequestQueue
import com.google.gson.annotations.SerializedName
import org.wordpress.android.fluxc.Dispatcher
import org.wordpress.android.fluxc.generated.endpoint.WPCOMV2
import org.wordpress.android.fluxc.model.JWTToken
import org.wordpress.android.fluxc.model.SiteModel
import org.wordpress.android.fluxc.network.BaseRequest.GenericErrorType
import org.wordpress.android.fluxc.network.UserAgent
Expand Down Expand Up @@ -43,7 +44,7 @@ class JetpackAIRestClient @Inject constructor(
)

return when (response) {
is Response.Success -> JetpackAIJWTTokenResponse.Success(response.data.token)
is Response.Success -> JetpackAIJWTTokenResponse.Success(JWTToken(response.data.token))
is Response.Error -> JetpackAIJWTTokenResponse.Error(
response.error.toJetpackAICompletionsError(),
response.error.message
Expand All @@ -52,14 +53,14 @@ class JetpackAIRestClient @Inject constructor(
}

suspend fun fetchJetpackAITextCompletion(
token: String,
token: JWTToken,
prompt: String,
feature: String
): JetpackAICompletionsResponse {
val url = WPCOMV2.text_completion.url
val body = mutableMapOf<String, String>()
body.apply {
put("token", token)
put("token", token.value)
put("prompt", prompt)
put("feature", feature)
put("_fields", FIELDS_TO_REQUEST)
Expand Down Expand Up @@ -136,7 +137,7 @@ class JetpackAIRestClient @Inject constructor(
)

sealed class JetpackAIJWTTokenResponse {
data class Success(val token: String) : JetpackAIJWTTokenResponse()
data class Success(val token: JWTToken) : JetpackAIJWTTokenResponse()
data class Error(
val type: JetpackAICompletionsErrorType,
val message: String? = null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.wordpress.android.fluxc.store.jetpackai

import org.wordpress.android.fluxc.model.JWTToken
import org.wordpress.android.fluxc.model.SiteModel
import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient
import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient.JetpackAICompletionsErrorType.AUTH_ERROR
Expand All @@ -8,20 +9,17 @@ import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestCli
import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient.JetpackAIJWTTokenResponse.Error
import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient.JetpackAIJWTTokenResponse.Success
import org.wordpress.android.fluxc.tools.CoroutineEngine
import org.wordpress.android.fluxc.utils.PreferenceUtils.PreferenceUtilsWrapper
import org.wordpress.android.util.AppLog
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class JetpackAIStore @Inject constructor(
private val jetpackAIRestClient: JetpackAIRestClient,
private val coroutineEngine: CoroutineEngine,
private val preferenceUtils: PreferenceUtilsWrapper
private val coroutineEngine: CoroutineEngine
) {
companion object {
const val JETPACK_AI_JWT_TOKEN_KEY = "JETPACK_AI_JWT_TOKEN_KEY"
}
private var token: JWTToken? = null

/**
* Fetches Jetpack AI completions for a given prompt to be used on a particular post.
*
Expand Down Expand Up @@ -78,59 +76,46 @@ class JetpackAIStore @Inject constructor(
caller = this,
loggedMessage = "fetch Jetpack AI completions"
) {
val token = preferenceUtils.getFluxCPreferences().getString(JETPACK_AI_JWT_TOKEN_KEY, null)
val token = token?.validateExpiryDate()?.validateBlogId(site.siteId)
?: fetchJetpackAIJWTToken(site).let { tokenResponse ->
when (tokenResponse) {
is Error -> {
return@withDefaultContext JetpackAICompletionsResponse.Error(
type = AUTH_ERROR,
message = tokenResponse.message,
)
}

val result = if (token != null) {
jetpackAIRestClient.fetchJetpackAITextCompletion(token, prompt, feature)
} else {
val jwtTokenResponse = fetchJetpackAIJWTToken(site)
fetchCompletionsWithToken(jwtTokenResponse, prompt, feature)
}
is Success -> {
token = tokenResponse.token
tokenResponse.token
}
}
}

val result = jetpackAIRestClient.fetchJetpackAITextCompletion(token, prompt, feature)

return@withDefaultContext when {
// Fetch token anew if using existing token returns AUTH_ERROR
result is JetpackAICompletionsResponse.Error && result.type == AUTH_ERROR -> {
val jwtTokenResponse = fetchJetpackAIJWTToken(site)
fetchCompletionsWithToken(jwtTokenResponse, prompt, feature)
// Remove cached token
this@JetpackAIStore.token = null
fetchJetpackAICompletions(site, prompt, feature)
}

else -> result
}
}

private suspend fun fetchCompletionsWithToken(
jwtTokenResponse: JetpackAIJWTTokenResponse,
prompt: String,
feature: String
): JetpackAICompletionsResponse {
return when (jwtTokenResponse) {
is Error -> {
JetpackAICompletionsResponse.Error(
type = AUTH_ERROR,
message = jwtTokenResponse.message,
)
}

is Success -> {
preferenceUtils.getFluxCPreferences().edit().putString(
JETPACK_AI_JWT_TOKEN_KEY, jwtTokenResponse.token
).apply()

jetpackAIRestClient.fetchJetpackAITextCompletion(
jwtTokenResponse.token,
prompt,
feature
)
}
private suspend fun fetchJetpackAIJWTToken(site: SiteModel): JetpackAIJWTTokenResponse =
coroutineEngine.withDefaultContext(
tag = AppLog.T.API,
caller = this,
loggedMessage = "fetch Jetpack AI JWT token"
) {
jetpackAIRestClient.fetchJetpackAIJWTToken(site)
}
}

private suspend fun fetchJetpackAIJWTToken(site: SiteModel)
: JetpackAIJWTTokenResponse = coroutineEngine.withDefaultContext(
tag = AppLog.T.API,
caller = this,
loggedMessage = "fetch Jetpack AI JWT token"
) {
jetpackAIRestClient.fetchJetpackAIJWTToken(site)
}
private fun JWTToken.validateBlogId(blogId: Long): JWTToken? =
if (getPayloadItem("blog_id")?.toLong() == blogId) this else null
}

0 comments on commit ba92dd0

Please sign in to comment.