From 043ae51c924c3dc0ba6fedbc7d0d53e89ae81921 Mon Sep 17 00:00:00 2001 From: Nikita Voloshin Date: Thu, 21 Apr 2022 22:33:24 +0300 Subject: [PATCH] Fix Rasa response deserialization --- activators/rasa/build.gradle.kts | 2 + .../activator/rasa/RasaIntentActivator.kt | 4 +- .../jaicf/activator/rasa/api/RasaApi.kt | 14 +- .../rasa/api/RasaParseMessageResponse.kt | 8 +- .../rasa/test/RasaDeserializationTest.kt | 132 ++++++++++++++++++ 5 files changed, 149 insertions(+), 11 deletions(-) create mode 100644 activators/rasa/src/test/kotlin/com/justai/jaicf/activator/rasa/test/RasaDeserializationTest.kt diff --git a/activators/rasa/build.gradle.kts b/activators/rasa/build.gradle.kts index cb4ef794..71bae12c 100644 --- a/activators/rasa/build.gradle.kts +++ b/activators/rasa/build.gradle.kts @@ -16,4 +16,6 @@ dependencies { api(ktor("ktor-client-cio")) api(ktor("ktor-client-serialization-jvm")) api(ktor("ktor-client-logging-jvm")) + testImplementation(ktor("ktor-client-mock")) + testImplementation("org.jetbrains.kotlin:kotlin-test-junit" version { kotlin }) } diff --git a/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/RasaIntentActivator.kt b/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/RasaIntentActivator.kt index be85fd7c..8bc2e6bf 100644 --- a/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/RasaIntentActivator.kt +++ b/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/RasaIntentActivator.kt @@ -10,6 +10,7 @@ import com.justai.jaicf.api.BotRequest import com.justai.jaicf.api.hasQuery import com.justai.jaicf.context.BotContext import com.justai.jaicf.model.scenario.ScenarioModel +import kotlinx.serialization.json.jsonObject import java.util.* class RasaIntentActivator( @@ -24,7 +25,8 @@ class RasaIntentActivator( override fun recogniseIntent(botContext: BotContext, request: BotRequest): List { val messageId = UUID.randomUUID().toString() - val json = api.parseMessage(RasaParseMessageRequest(request.input, messageId)) ?: return emptyList() + val rawJson = api.parseMessage(RasaParseMessageRequest(request.input, messageId)) ?: return emptyList() + val json = api.Json.parseToJsonElement(rawJson).jsonObject val response = api.Json.decodeFromJsonElement(RasaParseMessageResponse.serializer(), json) response.ranking ?: return emptyList() diff --git a/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaApi.kt b/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaApi.kt index cb9abcfa..86727751 100644 --- a/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaApi.kt +++ b/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaApi.kt @@ -3,6 +3,7 @@ package com.justai.jaicf.activator.rasa.api import com.justai.jaicf.helpers.http.toUrl import com.justai.jaicf.helpers.logging.WithLogger import io.ktor.client.HttpClient +import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.cio.CIO import io.ktor.client.features.json.JsonFeature import io.ktor.client.features.json.serializer.KotlinxSerializer @@ -12,15 +13,16 @@ import io.ktor.http.ContentType import io.ktor.http.contentType import kotlinx.coroutines.runBlocking import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonObject class RasaApi( - private val uri: String + private val uri: String, + logLevel: LogLevel = LogLevel.INFO, + httpClient: HttpClientEngine = CIO.create() ) : WithLogger { internal val Json = Json { ignoreUnknownKeys = true; isLenient = true } - private val client = HttpClient(CIO) { + private val client = HttpClient(httpClient) { expectSuccess = true install(JsonFeature) { @@ -28,13 +30,13 @@ class RasaApi( } install(Logging) { - level = LogLevel.INFO + level = logLevel } } - fun parseMessage(request: RasaParseMessageRequest): JsonObject? = runBlocking { + fun parseMessage(request: RasaParseMessageRequest): String? = runBlocking { try { - client.post("$uri/model/parse".toUrl()) { + client.post("$uri/model/parse".toUrl()) { contentType(ContentType.Application.Json) body = request } diff --git a/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaParseMessageResponse.kt b/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaParseMessageResponse.kt index c1c04eaf..a70b6b35 100644 --- a/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaParseMessageResponse.kt +++ b/activators/rasa/src/main/kotlin/com/justai/jaicf/activator/rasa/api/RasaParseMessageResponse.kt @@ -6,18 +6,18 @@ import kotlinx.serialization.Serializable @Serializable data class RasaParseMessageResponse( val text: String, - val intent: Intent?, - val entities: List?, + val intent: Intent? = null, + val entities: List? = null, @SerialName("intent_ranking") - val ranking: List? + val ranking: List? = null ) @Serializable data class Entity( val start: Int, val end: Int, - val confidence: Float?, + val confidence: Float? = null, val value: String, val entity: String ) diff --git a/activators/rasa/src/test/kotlin/com/justai/jaicf/activator/rasa/test/RasaDeserializationTest.kt b/activators/rasa/src/test/kotlin/com/justai/jaicf/activator/rasa/test/RasaDeserializationTest.kt new file mode 100644 index 00000000..b0f49b4c --- /dev/null +++ b/activators/rasa/src/test/kotlin/com/justai/jaicf/activator/rasa/test/RasaDeserializationTest.kt @@ -0,0 +1,132 @@ +package com.justai.jaicf.activator.rasa.test + +import com.justai.jaicf.activator.Activator +import com.justai.jaicf.activator.rasa.RasaActivatorContext +import com.justai.jaicf.activator.rasa.RasaIntentActivator +import com.justai.jaicf.activator.rasa.api.RasaApi +import com.justai.jaicf.activator.rasa.api.RasaParseMessageRequest +import com.justai.jaicf.activator.rasa.rasa +import com.justai.jaicf.activator.selection.ActivationSelector +import com.justai.jaicf.api.QueryBotRequest +import com.justai.jaicf.builder.Scenario +import com.justai.jaicf.context.BotContext +import com.justai.jaicf.context.DialogContext +import io.ktor.client.engine.mock.* +import io.ktor.http.* +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.floatOrNull +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import org.intellij.lang.annotations.Language +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class RasaDeserializationTest { + + private fun api(response: String) = RasaApi("", httpClient = MockEngine { respond(response.toByteArray()) }) + + private fun activator(response: String) = RasaIntentActivator.Factory(api(response)).create( + Scenario { + state("hello") { + activators { + intent("Hello") + } + } + }.model + ) + + private fun rasaResponse(entities: String) = """ + { + "text": "Sample", + "intents": [ + { + "name": "Hello", + "confidence": 0.5 + } + ], + "intent_ranking": [ + { + "name": "Hello", + "confidence": 0.5 + } + ], + $entities + } + """.trimIndent() + + private fun Activator.activate() = assertNotNull( + activate( + BotContext("", DialogContext()), + QueryBotRequest("", ""), + ActivationSelector.default + )?.context as? RasaActivatorContext + ) + + @Test + fun `Should deserialize entity`() { + val ctx = activator(rasaResponse(""" + "entities": [ + { + "start": 0, + "end": 10, + "confidence": 0.5, + "value": "value", + "entity": "entity" + } + ] + """.trimIndent() + )).activate() + + val entity = assertNotNull(ctx.entities.find { it.entity == "entity" }) + assertEquals(0.5f, entity.confidence) + } + + @Test + fun `Should deserialize entity without confidence`() { + val ctx = activator(rasaResponse(""" + "entities": [ + { + "start": 0, + "end": 10, + "value": "value", + "entity": "entity" + } + ] + """.trimIndent() + )).activate() + + val entity = assertNotNull(ctx.entities.find { it.entity == "entity" }) + assertEquals(null, entity.confidence) + } + + @Test + fun `Should store raw request in ActivatorContext`() { + val ctx = activator(rasaResponse(""" + "entities": [ + { + "start": 0, + "end": 10, + "confidence_entity": 0.5, + "value": "value", + "entity": "entity" + } + ] + """.trimIndent() + )).activate() + + val entity = assertNotNull(ctx.entities.find { it.entity == "entity" }) + assertEquals(null, entity.confidence) + assertEquals( + 0.5f, + ctx.rawResponse["entities"]?.jsonArray + ?.find { it.jsonObject["entity"]?.jsonPrimitive?.contentOrNull == "entity" }?.jsonObject + ?.get("confidence_entity")?.jsonPrimitive?.floatOrNull + ) + } +} \ No newline at end of file