Skip to content

Commit

Permalink
Implement logic to validate the token's blog_id property
Browse files Browse the repository at this point in the history
  • Loading branch information
hichamboushaba committed Oct 6, 2023
1 parent ef45c3d commit 8fd387d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ import org.wordpress.android.fluxc.model.JWTToken

class JWTTokenTests {
@Test
fun given_a_valid_token__when_takeIfValid_is_called__then_return_it() {
fun given_a_valid_token__when_validateExpiryDate_is_called__then_return_it() {
val token = generateToken(expired = false)

val result = token.takeIfValid()
val result = token.validateExpiryDate()

Assert.assertNotNull(result)
}

@Test
fun given_an_expired_token__when_takeIfValid_is_called__then_return_null() {
fun given_an_expired_token__when_validateExpiryDate_is_called__then_return_null() {
val token = generateToken(expired = true)

val result = token.takeIfValid()
val result = token.validateExpiryDate()

Assert.assertNull(result)
}
Expand Down
28 changes: 16 additions & 12 deletions fluxc/src/main/java/org/wordpress/android/fluxc/model/JWTToken.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.wordpress.android.fluxc.model

import android.util.Base64
import org.json.JSONException
import org.json.JSONObject

@JvmInline
Expand All @@ -11,21 +10,26 @@ value class JWTToken(
/**
* Returns the token if it is still valid, or null if it is expired.
*/
@Suppress("SwallowedException", "MagicNumber")
fun takeIfValid(): JWTToken? {
fun JSONObject.getLongOrNull(name: String) = try {
this.getLong(name)
} catch (e: JSONException) {
null
}
@Suppress("MagicNumber")
fun validateExpiryDate(): JWTToken? {
fun JSONObject.getLongOrNull(name: String) = this.optLong(name, Long.MAX_VALUE).takeIf { it != Long.MAX_VALUE }

val payload = this.value.split(".")[1]
val claimsJson = String(Base64.decode(payload, Base64.DEFAULT))
val claims = JSONObject(claimsJson)
val payloadJson = getPayloadJson()
val expiration = payloadJson.getLongOrNull("exp")
?: payloadJson.getLongOrNull("expires")
?: return null

val expiration = claims.getLongOrNull("exp") ?: claims.getLongOrNull("expires") ?: return null
val now = System.currentTimeMillis() / 1000

return if (expiration > now) this else null
}

fun getPayloadItem(key: String): String? {
return getPayloadJson().optString(key)
}

private fun getPayloadJson(): JSONObject {
val payloadEncoded = this.value.split(".")[1]
return JSONObject(String(Base64.decode(payloadEncoded, Base64.DEFAULT)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,22 @@ class JetpackAIStore @Inject constructor(
caller = this,
loggedMessage = "fetch Jetpack AI completions"
) {
val token = token?.takeIfValid() ?: fetchJetpackAIJWTToken(site).let { tokenResponse ->
when (tokenResponse) {
is Error -> {
return@withDefaultContext JetpackAICompletionsResponse.Error(
type = AUTH_ERROR,
message = tokenResponse.message,
)
}
val token = token?.validateExpiryDate()?.validateBlogId(site.siteId)
?: fetchJetpackAIJWTToken(site).let { tokenResponse ->
when (tokenResponse) {
is Error -> {
return@withDefaultContext JetpackAICompletionsResponse.Error(
type = AUTH_ERROR,
message = tokenResponse.message,
)
}

is Success -> {
token = tokenResponse.token
tokenResponse.token
is Success -> {
token = tokenResponse.token
tokenResponse.token
}
}
}
}

val result = jetpackAIRestClient.fetchJetpackAITextCompletion(token, prompt, feature)

Expand All @@ -114,4 +115,7 @@ class JetpackAIStore @Inject constructor(
) {
jetpackAIRestClient.fetchJetpackAIJWTToken(site)
}

private fun JWTToken.validateBlogId(blogId: Long): JWTToken? =
if (getPayloadItem("blog_id")?.toLong() == blogId) this else null
}

0 comments on commit 8fd387d

Please sign in to comment.