From 6696bbe720b8b370319f6853989883a30dba1b1c Mon Sep 17 00:00:00 2001 From: Marjan Kalanaki Date: Fri, 4 Oct 2024 16:53:26 +0100 Subject: [PATCH] handle transcription failure output --- .../ExternalTranscriptionExtractor.scala | 63 ++++++++--- .../ExternalTranscriptionWorker.scala | 102 ++++++++++++------ backend/conf/application.conf | 2 +- 3 files changed, 123 insertions(+), 44 deletions(-) diff --git a/backend/app/extraction/ExternalTranscriptionExtractor.scala b/backend/app/extraction/ExternalTranscriptionExtractor.scala index 5836a5a1..f821da10 100644 --- a/backend/app/extraction/ExternalTranscriptionExtractor.scala +++ b/backend/app/extraction/ExternalTranscriptionExtractor.scala @@ -1,10 +1,10 @@ package extraction import com.amazonaws.services.sqs.AmazonSQS -import com.amazonaws.services.sqs.model.SendMessageRequest +import com.amazonaws.services.sqs.model.{MessageAttributeValue, SendMessageRequest} import model.manifest.Blob import org.joda.time.DateTime -import play.api.libs.json.Json +import play.api.libs.json.{JsError, JsResult, JsValue, Json, Reads} import services.index.Index import services.{ObjectStorage, TranscribeConfig} import utils._ @@ -12,6 +12,7 @@ import utils.attempt.Failure import java.util.UUID import scala.concurrent.ExecutionContext +import scala.jdk.CollectionConverters.MapHasAsJava case class SignedUrl(url: String, key: String) @@ -34,18 +35,51 @@ object TranscriptionJob { implicit val formats = Json.format[TranscriptionJob] } -/** - * id: z.string(), - * originalFilename: z.string(), - * userEmail: z.string(), - * status: z.literal('SUCCESS'), - * languageCode: z.string(), - * outputBucketKeys: OutputBucketKeys, - */ +sealed trait TranscriptionOutput { + def id: String + def originalFilename: String + def userEmail: String + def isTranslation: Boolean + def status: String +} + +case class TranscriptionOutputSuccess( + id: String, + originalFilename: String, + userEmail: String, + isTranslation: Boolean, + status: String = "SUCCESS", + languageCode: String, + outputBucketKeys: OutputBucketKeys + ) extends TranscriptionOutput + +case class TranscriptionOutputFailure( + id: String, + originalFilename: String, + userEmail: String, + isTranslation: Boolean, + status: String = "FAILURE" + ) extends TranscriptionOutput + +object TranscriptionOutputSuccess { + implicit val format = Json.format[TranscriptionOutputSuccess] +} + +object TranscriptionOutputFailure { + implicit val format = Json.format[TranscriptionOutputFailure] +} -case class TranscriptionOutput(id: String, originalFilename: String, userEmail: String, status: String, languageCode: String, outputBucketKeys: OutputBucketKeys) object TranscriptionOutput { - implicit val formats = Json.format[TranscriptionOutput] + // Custom Reads to handle both message types + implicit val transcriptionOutputReads: Reads[TranscriptionOutput] = new Reads[TranscriptionOutput] { + def reads(json: JsValue): JsResult[TranscriptionOutput] = { + (json \ "status").as[String] match { + case "SUCCESS" => json.validate[TranscriptionOutputSuccess] + case "FAILURE" => json.validate[TranscriptionOutputFailure] + case other => JsError(s"Unknown status type: $other") + } + } + } } class ExternalTranscriptionExtractor(index: Index, transcribeConfig: TranscribeConfig, transcriptionStorage: ObjectStorage, outputStorage: ObjectStorage, amazonSQSClient: AmazonSQS)(implicit executionContext: ExecutionContext) extends ExternalExtractor with Logging { @@ -110,12 +144,15 @@ class ExternalTranscriptionExtractor(index: Index, transcribeConfig: TranscribeC transcriptionJob.flatMap { job => { try { - logger.info(s"sending message to Transcription Service Queue") + logger.info(s"sending message to Transcription Service Queue with message attribute") val sendMessageCommand = new SendMessageRequest() .withQueueUrl(transcribeConfig.transcriptionServiceQueueUrl) .withMessageBody(Json.stringify(Json.toJson(job))) .withMessageGroupId(UUID.randomUUID().toString) + .withMessageAttributes( + Map("BlobId" -> new MessageAttributeValue().withDataType("String").withStringValue(blob.uri.value)).asJava + ) Right(amazonSQSClient.sendMessage(sendMessageCommand)) } catch { case e: Failure => Left(e) diff --git a/backend/app/extraction/ExternalTranscriptionWorker.scala b/backend/app/extraction/ExternalTranscriptionWorker.scala index 41a4a0f7..8899347d 100644 --- a/backend/app/extraction/ExternalTranscriptionWorker.scala +++ b/backend/app/extraction/ExternalTranscriptionWorker.scala @@ -4,55 +4,83 @@ import cats.syntax.either._ import com.amazonaws.services.sqs.AmazonSQS import com.amazonaws.services.sqs.model.{DeleteMessageRequest, Message, ReceiveMessageRequest} import model.{English, Languages, Uri} -import play.api.libs.json.{Format, JsError, JsSuccess, Json} +import play.api.libs.json.{JsError, JsSuccess, Json} import services.index.Index import services.manifest.WorkerManifest import services.{ObjectStorage, TranscribeConfig} import utils.Logging -import utils.attempt.{ExternalTranscriptionFailure, JsonParseFailure} +import utils.attempt.{ExternalTranscriptionFailure, Failure, JsonParseFailure} import java.nio.charset.StandardCharsets import scala.concurrent.ExecutionContext import scala.jdk.CollectionConverters.CollectionHasAsScala +import scala.util.Try - +case class TranscriptionMessageAttribute(receiveCount: Int, messageGroupId: String) class ExternalTranscriptionWorker(manifest: WorkerManifest, amazonSQSClient: AmazonSQS, transcribeConfig: TranscribeConfig, blobStorage: ObjectStorage, index: Index)(implicit executionContext: ExecutionContext) extends Logging{ def pollForResults(): Int = { - logger.info("Fetching messages from external transcription output queue") + logger.info(s"Fetching messages from external transcription output queue ${transcribeConfig.transcriptionOutputQueueUrl}") + val messages = amazonSQSClient.receiveMessage( - new ReceiveMessageRequest(transcribeConfig.transcriptionOutputQueueUrl).withMaxNumberOfMessages(10) + new ReceiveMessageRequest(transcribeConfig.transcriptionOutputQueueUrl) + .withMaxNumberOfMessages(10) + .withAttributeNames("MessageGroupId", "ApproximateReceiveCount") ).getMessages if (messages.size() > 0) logger.info(s"retrieved ${messages.size()} messages from queue Transcription Output Queue") - else logger.info("No message found") + else + logger.info("No message found") messages.asScala.toList.foldLeft(0) { (completed, message) => - val result = for { - transcriptionOutput <- parseMessage[TranscriptionOutput](message) - transcription <- blobStorage.get(transcriptionOutput.outputBucketKeys.text) - txt = new String(transcription.readAllBytes(), StandardCharsets.UTF_8) - _ <- addDocumentTranscription(transcriptionOutput, txt) - _ <- markAsComplete(transcriptionOutput.id, "ExternalTranscriptionExtractor") - } yield { - amazonSQSClient.deleteMessage( - new DeleteMessageRequest(transcribeConfig.transcriptionOutputQueueUrl, message.getReceiptHandle) - ) - logger.debug(s"deleted message for ${transcriptionOutput.id}") - } - - result match { - case Right(_) => - completed + 1 - case Left(failure) => - logger.error(s"failed to process sqs message, ${failure.msg}", failure.toThrowable) + getMessageAttribute(message) match { + case Right(messageAttributes) => + handleMessage(message, messageAttributes, completed) + case Left(error) => + logger.error(s"Could not get message attributes from transcription output message, therefore can not update extractor. Message id: ${message.getMessageId}", error) completed } } } - private def markAsComplete(id: String, extractorName: String) = { + + private def handleMessage(message: Message, messageAttributes: TranscriptionMessageAttribute, completed: Int) = { + val result = for { + transcriptionOutput <- parseMessage(message) + transcription <- blobStorage.get(transcriptionOutput.outputBucketKeys.text) + txt = new String(transcription.readAllBytes(), StandardCharsets.UTF_8) + _ <- addDocumentTranscription(transcriptionOutput, txt) + _ <- markExternalExtractorAsComplete(transcriptionOutput.id, "ExternalTranscriptionExtractor") + } yield { + amazonSQSClient.deleteMessage( + new DeleteMessageRequest(transcribeConfig.transcriptionOutputQueueUrl, message.getReceiptHandle) + ) + logger.debug(s"deleted message for ${transcriptionOutput.id}") + } + + result match { + case Right(_) => + completed + 1 + case Left(failure) => + logger.error(s"failed to process sqs message", failure.toThrowable) + if (messageAttributes.receiveCount > 2) { + markAsFailure(new Uri(messageAttributes.messageGroupId), "ExternalTranscriptionExtractor", failure.msg) + } + completed + } + } + + private def getMessageAttribute(message: Message) = { + Try { + val attributes = message.getAttributes + val receiveCount = attributes.get("ApproximateReceiveCount").toInt + val messageGroupId = attributes.get("MessageGroupId") + TranscriptionMessageAttribute(receiveCount, messageGroupId) + }.toEither + } + + private def markExternalExtractorAsComplete(id: String, extractorName: String) = { val result = manifest.markExternalAsComplete(id, extractorName) result.leftMap { failure => logger.error(s"Failed to mark '${id}' processed by $extractorName as complete: ${failure.msg}") @@ -60,7 +88,7 @@ class ExternalTranscriptionWorker(manifest: WorkerManifest, amazonSQSClient: Ama } } - private def addDocumentTranscription(transcriptionOutput: TranscriptionOutput, text: String) = { + private def addDocumentTranscription(transcriptionOutput: TranscriptionOutputSuccess, text: String) = { Either.catchNonFatal { index.addDocumentTranscription(Uri(transcriptionOutput.originalFilename), text, None, Languages.getByIso6391Code(transcriptionOutput.languageCode).getOrElse(English)) .recoverWith { @@ -75,12 +103,26 @@ class ExternalTranscriptionWorker(manifest: WorkerManifest, amazonSQSClient: Ama } } - private def parseMessage[T: Format](message: Message) = { + private def parseMessage(message: Message): Either[Failure, TranscriptionOutputSuccess] = { val json = Json.parse(message.getBody) - Json.fromJson[T](json) match { - case JsSuccess(output, _) => Right(output) - case JsError(error) => Left(JsonParseFailure(error)) + Json.fromJson[TranscriptionOutput](json) match { + case JsSuccess(output: TranscriptionOutputSuccess, _) => + Right(output) + + case JsSuccess(output: TranscriptionOutputFailure, _) => + Left(ExternalTranscriptionFailure.apply(new Error(s"External transcription service failed to transcribe the file ${output.originalFilename}"))) + + case JsError(errors) => + Left(JsonParseFailure(errors)) + } + } + + private def markAsFailure(uri: Uri, extractorName: String, failureMsg: String): Unit = { + logger.error(s"Error in '${extractorName} processing ${uri}': ${failureMsg}") + + manifest.logExtractionFailure(uri, extractorName, failureMsg).left.foreach { f => + logger.error(s"Failed to log extractor in manifest: ${f.msg}") } } } diff --git a/backend/conf/application.conf b/backend/conf/application.conf index de52c838..75cfc308 100644 --- a/backend/conf/application.conf +++ b/backend/conf/application.conf @@ -208,7 +208,7 @@ ocr { transcribe { whisperModelFilename = "ggml-base.bin" transcriptionServiceQueueUrl = "http://sqs.eu-west-1.localhost.localstack.cloud:4566/000000000000/transcription-service-task-queue-DEV.fifo" - transcriptionOutputQueueUrl = "http://sqs.eu-west-1.localhost.localstack.cloud:4566/000000000000/giant-output-queue-DEV" + transcriptionOutputQueueUrl = "http://sqs.eu-west-1.localhost.localstack.cloud:4566/000000000000/giant-output-queue-DEV.fifo" } sqs {