diff --git a/example/src/androidTest/java/org/wordpress/android/fluxc/JWTTokenTests.kt b/example/src/androidTest/java/org/wordpress/android/fluxc/JWTTokenTests.kt new file mode 100644 index 0000000000..162592a8cf --- /dev/null +++ b/example/src/androidTest/java/org/wordpress/android/fluxc/JWTTokenTests.kt @@ -0,0 +1,47 @@ +package org.wordpress.android.fluxc + +import android.util.Base64 +import org.junit.Assert +import org.junit.Test +import org.wordpress.android.fluxc.model.JWTToken + +class JWTTokenTests { + @Test + fun given_a_valid_token__when_validateExpiryDate_is_called__then_return_it() { + val token = generateToken(expired = false) + + val result = token.validateExpiryDate() + + Assert.assertNotNull(result) + } + + @Test + fun given_an_expired_token__when_validateExpiryDate_is_called__then_return_null() { + val token = generateToken(expired = true) + + val result = token.validateExpiryDate() + + Assert.assertNull(result) + } + + private fun generateToken(expired: Boolean): JWTToken { + val expirationTime = System.currentTimeMillis() / 1000 + if (expired) -100 else 100 + + // Sample token from https://jwt.io/ modifier with an expiration time + val header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" + val payload = Base64.encode( + """ + { + "sub": "1234567890", + "name": "John Doe", + "iat": 1516239022, + "exp": $expirationTime, + "expires": $expirationTime + } + """.trimIndent().toByteArray(), Base64.DEFAULT + ).decodeToString() + val signature = "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + + return JWTToken("$header.$payload.$signature") + } +} diff --git a/fluxc/src/main/java/org/wordpress/android/fluxc/model/JWTToken.kt b/fluxc/src/main/java/org/wordpress/android/fluxc/model/JWTToken.kt new file mode 100644 index 0000000000..5b1ba89750 --- /dev/null +++ b/fluxc/src/main/java/org/wordpress/android/fluxc/model/JWTToken.kt @@ -0,0 +1,35 @@ +package org.wordpress.android.fluxc.model + +import android.util.Base64 +import org.json.JSONObject + +@JvmInline +value class JWTToken( + val value: String +) { + /** + * Returns the token if it is still valid, or null if it is expired. + */ + @Suppress("MagicNumber") + fun validateExpiryDate(): JWTToken? { + fun JSONObject.getLongOrNull(name: String) = this.optLong(name, Long.MAX_VALUE).takeIf { it != Long.MAX_VALUE } + + val payloadJson = getPayloadJson() + val expiration = payloadJson.getLongOrNull("exp") + ?: payloadJson.getLongOrNull("expires") + ?: return null + + val now = System.currentTimeMillis() / 1000 + + return if (expiration > now) this else null + } + + fun getPayloadItem(key: String): String? { + return getPayloadJson().optString(key) + } + + private fun getPayloadJson(): JSONObject { + val payloadEncoded = this.value.split(".")[1] + return JSONObject(String(Base64.decode(payloadEncoded, Base64.DEFAULT))) + } +} diff --git a/fluxc/src/main/java/org/wordpress/android/fluxc/network/rest/wpcom/jetpackai/JetpackAIRestClient.kt b/fluxc/src/main/java/org/wordpress/android/fluxc/network/rest/wpcom/jetpackai/JetpackAIRestClient.kt index acfb8148f7..48bd479d29 100644 --- a/fluxc/src/main/java/org/wordpress/android/fluxc/network/rest/wpcom/jetpackai/JetpackAIRestClient.kt +++ b/fluxc/src/main/java/org/wordpress/android/fluxc/network/rest/wpcom/jetpackai/JetpackAIRestClient.kt @@ -5,6 +5,7 @@ import com.android.volley.RequestQueue import com.google.gson.annotations.SerializedName import org.wordpress.android.fluxc.Dispatcher import org.wordpress.android.fluxc.generated.endpoint.WPCOMV2 +import org.wordpress.android.fluxc.model.JWTToken import org.wordpress.android.fluxc.model.SiteModel import org.wordpress.android.fluxc.network.BaseRequest.GenericErrorType import org.wordpress.android.fluxc.network.UserAgent @@ -43,7 +44,7 @@ class JetpackAIRestClient @Inject constructor( ) return when (response) { - is Response.Success -> JetpackAIJWTTokenResponse.Success(response.data.token) + is Response.Success -> JetpackAIJWTTokenResponse.Success(JWTToken(response.data.token)) is Response.Error -> JetpackAIJWTTokenResponse.Error( response.error.toJetpackAICompletionsError(), response.error.message @@ -52,14 +53,14 @@ class JetpackAIRestClient @Inject constructor( } suspend fun fetchJetpackAITextCompletion( - token: String, + token: JWTToken, prompt: String, feature: String ): JetpackAICompletionsResponse { val url = WPCOMV2.text_completion.url val body = mutableMapOf() body.apply { - put("token", token) + put("token", token.value) put("prompt", prompt) put("feature", feature) put("_fields", FIELDS_TO_REQUEST) @@ -136,7 +137,7 @@ class JetpackAIRestClient @Inject constructor( ) sealed class JetpackAIJWTTokenResponse { - data class Success(val token: String) : JetpackAIJWTTokenResponse() + data class Success(val token: JWTToken) : JetpackAIJWTTokenResponse() data class Error( val type: JetpackAICompletionsErrorType, val message: String? = null diff --git a/fluxc/src/main/java/org/wordpress/android/fluxc/store/jetpackai/JetpackAIStore.kt b/fluxc/src/main/java/org/wordpress/android/fluxc/store/jetpackai/JetpackAIStore.kt index 577df73fed..82ae4344d3 100644 --- a/fluxc/src/main/java/org/wordpress/android/fluxc/store/jetpackai/JetpackAIStore.kt +++ b/fluxc/src/main/java/org/wordpress/android/fluxc/store/jetpackai/JetpackAIStore.kt @@ -1,5 +1,6 @@ package org.wordpress.android.fluxc.store.jetpackai +import org.wordpress.android.fluxc.model.JWTToken import org.wordpress.android.fluxc.model.SiteModel import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient.JetpackAICompletionsErrorType.AUTH_ERROR @@ -8,7 +9,6 @@ import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestCli import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient.JetpackAIJWTTokenResponse.Error import org.wordpress.android.fluxc.network.rest.wpcom.jetpackai.JetpackAIRestClient.JetpackAIJWTTokenResponse.Success import org.wordpress.android.fluxc.tools.CoroutineEngine -import org.wordpress.android.fluxc.utils.PreferenceUtils.PreferenceUtilsWrapper import org.wordpress.android.util.AppLog import javax.inject.Inject import javax.inject.Singleton @@ -16,12 +16,10 @@ import javax.inject.Singleton @Singleton class JetpackAIStore @Inject constructor( private val jetpackAIRestClient: JetpackAIRestClient, - private val coroutineEngine: CoroutineEngine, - private val preferenceUtils: PreferenceUtilsWrapper + private val coroutineEngine: CoroutineEngine ) { - companion object { - const val JETPACK_AI_JWT_TOKEN_KEY = "JETPACK_AI_JWT_TOKEN_KEY" - } + private var token: JWTToken? = null + /** * Fetches Jetpack AI completions for a given prompt to be used on a particular post. * @@ -78,59 +76,46 @@ class JetpackAIStore @Inject constructor( caller = this, loggedMessage = "fetch Jetpack AI completions" ) { - val token = preferenceUtils.getFluxCPreferences().getString(JETPACK_AI_JWT_TOKEN_KEY, null) + val token = token?.validateExpiryDate()?.validateBlogId(site.siteId) + ?: fetchJetpackAIJWTToken(site).let { tokenResponse -> + when (tokenResponse) { + is Error -> { + return@withDefaultContext JetpackAICompletionsResponse.Error( + type = AUTH_ERROR, + message = tokenResponse.message, + ) + } - val result = if (token != null) { - jetpackAIRestClient.fetchJetpackAITextCompletion(token, prompt, feature) - } else { - val jwtTokenResponse = fetchJetpackAIJWTToken(site) - fetchCompletionsWithToken(jwtTokenResponse, prompt, feature) - } + is Success -> { + token = tokenResponse.token + tokenResponse.token + } + } + } + + val result = jetpackAIRestClient.fetchJetpackAITextCompletion(token, prompt, feature) return@withDefaultContext when { // Fetch token anew if using existing token returns AUTH_ERROR result is JetpackAICompletionsResponse.Error && result.type == AUTH_ERROR -> { - val jwtTokenResponse = fetchJetpackAIJWTToken(site) - fetchCompletionsWithToken(jwtTokenResponse, prompt, feature) + // Remove cached token + this@JetpackAIStore.token = null + fetchJetpackAICompletions(site, prompt, feature) } else -> result } } - private suspend fun fetchCompletionsWithToken( - jwtTokenResponse: JetpackAIJWTTokenResponse, - prompt: String, - feature: String - ): JetpackAICompletionsResponse { - return when (jwtTokenResponse) { - is Error -> { - JetpackAICompletionsResponse.Error( - type = AUTH_ERROR, - message = jwtTokenResponse.message, - ) - } - - is Success -> { - preferenceUtils.getFluxCPreferences().edit().putString( - JETPACK_AI_JWT_TOKEN_KEY, jwtTokenResponse.token - ).apply() - - jetpackAIRestClient.fetchJetpackAITextCompletion( - jwtTokenResponse.token, - prompt, - feature - ) - } + private suspend fun fetchJetpackAIJWTToken(site: SiteModel): JetpackAIJWTTokenResponse = + coroutineEngine.withDefaultContext( + tag = AppLog.T.API, + caller = this, + loggedMessage = "fetch Jetpack AI JWT token" + ) { + jetpackAIRestClient.fetchJetpackAIJWTToken(site) } - } - private suspend fun fetchJetpackAIJWTToken(site: SiteModel) - : JetpackAIJWTTokenResponse = coroutineEngine.withDefaultContext( - tag = AppLog.T.API, - caller = this, - loggedMessage = "fetch Jetpack AI JWT token" - ) { - jetpackAIRestClient.fetchJetpackAIJWTToken(site) - } + private fun JWTToken.validateBlogId(blogId: Long): JWTToken? = + if (getPayloadItem("blog_id")?.toLong() == blogId) this else null }