Skip to content

Commit

Permalink
handle transcription failure output
Browse files Browse the repository at this point in the history
  • Loading branch information
marjisound committed Oct 4, 2024
1 parent 610c558 commit 6696bbe
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 44 deletions.
63 changes: 50 additions & 13 deletions backend/app/extraction/ExternalTranscriptionExtractor.scala
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package extraction

import com.amazonaws.services.sqs.AmazonSQS
import com.amazonaws.services.sqs.model.SendMessageRequest
import com.amazonaws.services.sqs.model.{MessageAttributeValue, SendMessageRequest}
import model.manifest.Blob
import org.joda.time.DateTime
import play.api.libs.json.Json
import play.api.libs.json.{JsError, JsResult, JsValue, Json, Reads}
import services.index.Index
import services.{ObjectStorage, TranscribeConfig}
import utils._
import utils.attempt.Failure

import java.util.UUID
import scala.concurrent.ExecutionContext
import scala.jdk.CollectionConverters.MapHasAsJava

case class SignedUrl(url: String, key: String)

Expand All @@ -34,18 +35,51 @@ object TranscriptionJob {
implicit val formats = Json.format[TranscriptionJob]
}

/**
* Common shape of the messages that are decoded as a `TranscriptionOutput`.
*
* The members declared here are the ones shared by both concrete message
* types, [[TranscriptionOutputSuccess]] and [[TranscriptionOutputFailure]].
*
* The list below looks like the producing service's schema for a *successful*
* message (NOTE(review): appears to be copied from the service's schema
* definition - confirm it is still current). It does not mention
* `isTranslation`, which this trait requires, and `languageCode` /
* `outputBucketKeys` exist only on [[TranscriptionOutputSuccess]]:
*
* id: z.string(),
* originalFilename: z.string(),
* userEmail: z.string(),
* status: z.literal('SUCCESS'),
* languageCode: z.string(),
* outputBucketKeys: OutputBucketKeys,
*/
sealed trait TranscriptionOutput {
def id: String
def originalFilename: String
def userEmail: String
def isTranslation: Boolean
// Discriminator: the custom Reads for this trait dispatches on "SUCCESS" / "FAILURE".
def status: String
}

/**
* Output message whose `status` is "SUCCESS".
*
* On top of the members shared through [[TranscriptionOutput]] it carries
* `languageCode` (looked up by the worker as an ISO 639-1 code, falling back
* to English) and `outputBucketKeys`, whose `text` key is what the worker
* fetches from blob storage to obtain the transcript.
*/
case class TranscriptionOutputSuccess(
id: String,
originalFilename: String,
userEmail: String,
isTranslation: Boolean,
status: String = "SUCCESS",
languageCode: String,
outputBucketKeys: OutputBucketKeys
) extends TranscriptionOutput

/**
* Output message whose `status` is "FAILURE".
*
* Carries only the members shared through [[TranscriptionOutput]]; there is
* no language code and no output bucket keys because no transcript exists.
*/
case class TranscriptionOutputFailure(
id: String,
originalFilename: String,
userEmail: String,
isTranslation: Boolean,
status: String = "FAILURE"
) extends TranscriptionOutput

/** Companion holding the play-json codec for [[TranscriptionOutputSuccess]]. */
object TranscriptionOutputSuccess {
  // Implicit definitions get an explicit type so that implicit resolution does
  // not depend on inferring the macro's result type (and so the file keeps
  // compiling under Scala 3, which requires it). Fully qualified to avoid
  // touching the file's import block.
  implicit val format: play.api.libs.json.OFormat[TranscriptionOutputSuccess] =
    Json.format[TranscriptionOutputSuccess]
}

/** Companion holding the play-json codec for [[TranscriptionOutputFailure]]. */
object TranscriptionOutputFailure {
  // Implicit definitions get an explicit type so that implicit resolution does
  // not depend on inferring the macro's result type (and so the file keeps
  // compiling under Scala 3, which requires it). Fully qualified to avoid
  // touching the file's import block.
  implicit val format: play.api.libs.json.OFormat[TranscriptionOutputFailure] =
    Json.format[TranscriptionOutputFailure]
}

case class TranscriptionOutput(id: String, originalFilename: String, userEmail: String, status: String, languageCode: String, outputBucketKeys: OutputBucketKeys)
object TranscriptionOutput {
implicit val formats = Json.format[TranscriptionOutput]
// Custom Reads that picks the concrete message type from the "status" field:
// "SUCCESS" -> TranscriptionOutputSuccess, "FAILURE" -> TranscriptionOutputFailure.
// A missing or non-string "status", or any other value, yields a JsError
// rather than an exception, so callers matching on JsSuccess/JsError see
// every malformed message as a JsError.
implicit val transcriptionOutputReads: Reads[TranscriptionOutput] = new Reads[TranscriptionOutput] {
  def reads(json: JsValue): JsResult[TranscriptionOutput] =
    // validate (not as) so that an absent/mistyped status is reported as JsError
    (json \ "status").validate[String].flatMap[TranscriptionOutput] {
      case "SUCCESS" => json.validate[TranscriptionOutputSuccess]
      case "FAILURE" => json.validate[TranscriptionOutputFailure]
      case other => JsError(s"Unknown status type: $other")
    }
}
}

class ExternalTranscriptionExtractor(index: Index, transcribeConfig: TranscribeConfig, transcriptionStorage: ObjectStorage, outputStorage: ObjectStorage, amazonSQSClient: AmazonSQS)(implicit executionContext: ExecutionContext) extends ExternalExtractor with Logging {
Expand Down Expand Up @@ -110,12 +144,15 @@ class ExternalTranscriptionExtractor(index: Index, transcribeConfig: TranscribeC
transcriptionJob.flatMap {
job => {
try {
logger.info(s"sending message to Transcription Service Queue")
logger.info(s"sending message to Transcription Service Queue with message attribute")

val sendMessageCommand = new SendMessageRequest()
.withQueueUrl(transcribeConfig.transcriptionServiceQueueUrl)
.withMessageBody(Json.stringify(Json.toJson(job)))
.withMessageGroupId(UUID.randomUUID().toString)
.withMessageAttributes(
Map("BlobId" -> new MessageAttributeValue().withDataType("String").withStringValue(blob.uri.value)).asJava
)
Right(amazonSQSClient.sendMessage(sendMessageCommand))
} catch {
case e: Failure => Left(e)
Expand Down
102 changes: 72 additions & 30 deletions backend/app/extraction/ExternalTranscriptionWorker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,91 @@ import cats.syntax.either._
import com.amazonaws.services.sqs.AmazonSQS
import com.amazonaws.services.sqs.model.{DeleteMessageRequest, Message, ReceiveMessageRequest}
import model.{English, Languages, Uri}
import play.api.libs.json.{Format, JsError, JsSuccess, Json}
import play.api.libs.json.{JsError, JsSuccess, Json}
import services.index.Index
import services.manifest.WorkerManifest
import services.{ObjectStorage, TranscribeConfig}
import utils.Logging
import utils.attempt.{ExternalTranscriptionFailure, JsonParseFailure}
import utils.attempt.{ExternalTranscriptionFailure, Failure, JsonParseFailure}

import java.nio.charset.StandardCharsets
import scala.concurrent.ExecutionContext
import scala.jdk.CollectionConverters.CollectionHasAsScala
import scala.util.Try


/**
* The two SQS attributes the worker reads from each output-queue message.
*
* @param receiveCount   parsed from the message's "ApproximateReceiveCount" attribute
* @param messageGroupId the message's "MessageGroupId" attribute; the worker wraps it
*                       in a `Uri` when it records an extraction failure
*/
case class TranscriptionMessageAttribute(receiveCount: Int, messageGroupId: String)
class ExternalTranscriptionWorker(manifest: WorkerManifest, amazonSQSClient: AmazonSQS, transcribeConfig: TranscribeConfig, blobStorage: ObjectStorage, index: Index)(implicit executionContext: ExecutionContext) extends Logging{

def pollForResults(): Int = {
logger.info("Fetching messages from external transcription output queue")
logger.info(s"Fetching messages from external transcription output queue ${transcribeConfig.transcriptionOutputQueueUrl}")

val messages = amazonSQSClient.receiveMessage(
new ReceiveMessageRequest(transcribeConfig.transcriptionOutputQueueUrl).withMaxNumberOfMessages(10)
new ReceiveMessageRequest(transcribeConfig.transcriptionOutputQueueUrl)
.withMaxNumberOfMessages(10)
.withAttributeNames("MessageGroupId", "ApproximateReceiveCount")
).getMessages

if (messages.size() > 0)
logger.info(s"retrieved ${messages.size()} messages from queue Transcription Output Queue")
else logger.info("No message found")
else
logger.info("No message found")

messages.asScala.toList.foldLeft(0) { (completed, message) =>
val result = for {
transcriptionOutput <- parseMessage[TranscriptionOutput](message)
transcription <- blobStorage.get(transcriptionOutput.outputBucketKeys.text)
txt = new String(transcription.readAllBytes(), StandardCharsets.UTF_8)
_ <- addDocumentTranscription(transcriptionOutput, txt)
_ <- markAsComplete(transcriptionOutput.id, "ExternalTranscriptionExtractor")
} yield {
amazonSQSClient.deleteMessage(
new DeleteMessageRequest(transcribeConfig.transcriptionOutputQueueUrl, message.getReceiptHandle)
)
logger.debug(s"deleted message for ${transcriptionOutput.id}")
}

result match {
case Right(_) =>
completed + 1
case Left(failure) =>
logger.error(s"failed to process sqs message, ${failure.msg}", failure.toThrowable)
getMessageAttribute(message) match {
case Right(messageAttributes) =>
handleMessage(message, messageAttributes, completed)
case Left(error) =>
logger.error(s"Could not get message attributes from transcription output message, therefore can not update extractor. Message id: ${message.getMessageId}", error)
completed
}
}
}

private def markAsComplete(id: String, extractorName: String) = {

/**
 * Processes a single message from the transcription output queue.
 *
 * Steps, short-circuiting on the first failure: parse the message body, fetch
 * the transcript text from blob storage, add it to the index, mark the
 * external extractor as complete, and finally delete the message from the
 * queue.
 *
 * On failure the message is NOT deleted. The failure is logged, and once the
 * message has been received more than twice it is also recorded in the
 * manifest against the message group id (wrapped in a `Uri`).
 *
 * @param message           the SQS message to process
 * @param messageAttributes receive count and message group id read from `message`
 * @param completed         running count of successfully processed messages
 * @return `completed + 1` if the message was fully processed, otherwise `completed`
 */
private def handleMessage(message: Message, messageAttributes: TranscriptionMessageAttribute, completed: Int): Int = {
  val extractorName = "ExternalTranscriptionExtractor"

  val result = for {
    transcriptionOutput <- parseMessage(message)
    transcription <- blobStorage.get(transcriptionOutput.outputBucketKeys.text)
    // Always close the stream once it has been read so the underlying storage
    // connection is released even if reading fails.
    txt = (try new String(transcription.readAllBytes(), StandardCharsets.UTF_8) finally transcription.close())
    _ <- addDocumentTranscription(transcriptionOutput, txt)
    _ <- markExternalExtractorAsComplete(transcriptionOutput.id, extractorName)
  } yield {
    amazonSQSClient.deleteMessage(
      new DeleteMessageRequest(transcribeConfig.transcriptionOutputQueueUrl, message.getReceiptHandle)
    )
    logger.debug(s"deleted message for ${transcriptionOutput.id}")
  }

  result match {
    case Right(_) =>
      completed + 1
    case Left(failure) =>
      logger.error(s"failed to process sqs message", failure.toThrowable)
      // Only record the failure from the third delivery onwards, so earlier
      // attempts are left to be retried via redelivery.
      if (messageAttributes.receiveCount > 2) {
        markAsFailure(new Uri(messageAttributes.messageGroupId), extractorName, failure.msg)
      }
      completed
  }
}

/**
 * Reads the "ApproximateReceiveCount" and "MessageGroupId" attributes from an
 * SQS message.
 *
 * @param message the SQS message whose attributes are inspected
 * @return `Right` with both values, or `Left` with the exception raised when
 *         either attribute is absent or the receive count is not an integer
 */
private def getMessageAttribute(message: Message) = {
  Try {
    val attributes = message.getAttributes

    // java.util.Map#get returns null for an absent key; turn that into a
    // descriptive failure instead of letting a null travel downstream.
    def required(name: String): String =
      Option(attributes.get(name)).getOrElse(
        throw new NoSuchElementException(s"SQS message is missing the '$name' attribute")
      )

    val receiveCount = required("ApproximateReceiveCount").toInt
    val messageGroupId = required("MessageGroupId")
    TranscriptionMessageAttribute(receiveCount, messageGroupId)
  }.toEither
}

/**
 * Asks the manifest to mark the external extractor run for `id` as complete.
 *
 * A failure from the manifest is logged and then handed back unchanged on the
 * left, so the caller still sees the original failure.
 *
 * @param id            identifier passed to the manifest
 * @param extractorName name of the extractor being marked complete
 */
private def markExternalExtractorAsComplete(id: String, extractorName: String) =
  manifest
    .markExternalAsComplete(id, extractorName)
    .leftMap { markFailure =>
      logger.error(s"Failed to mark '${id}' processed by $extractorName as complete: ${markFailure.msg}")
      markFailure
    }

private def addDocumentTranscription(transcriptionOutput: TranscriptionOutput, text: String) = {
private def addDocumentTranscription(transcriptionOutput: TranscriptionOutputSuccess, text: String) = {
Either.catchNonFatal {
index.addDocumentTranscription(Uri(transcriptionOutput.originalFilename), text, None, Languages.getByIso6391Code(transcriptionOutput.languageCode).getOrElse(English))
.recoverWith {
Expand All @@ -75,12 +103,26 @@ class ExternalTranscriptionWorker(manifest: WorkerManifest, amazonSQSClient: Ama
}
}

private def parseMessage[T: Format](message: Message) = {
/**
* Decodes the body of an SQS message as a [[TranscriptionOutput]].
*
* - a "SUCCESS" payload becomes `Right(TranscriptionOutputSuccess)`
* - a "FAILURE" payload becomes `Left(ExternalTranscriptionFailure)` naming the original file
* - a body that does not validate becomes `Left(JsonParseFailure)`
*
* NOTE(review): `Json.parse` is called outside any error handling here, so a
* body that is not syntactically valid JSON presumably throws rather than
* producing a `Left` - confirm and consider capturing it.
*/
private def parseMessage(message: Message): Either[Failure, TranscriptionOutputSuccess] = {
val json = Json.parse(message.getBody)

Json.fromJson[T](json) match {
case JsSuccess(output, _) => Right(output)
case JsError(error) => Left(JsonParseFailure(error))
Json.fromJson[TranscriptionOutput](json) match {
case JsSuccess(output: TranscriptionOutputSuccess, _) =>
Right(output)

// The service reported that it could not transcribe the file: surface it as a failure.
case JsSuccess(output: TranscriptionOutputFailure, _) =>
Left(ExternalTranscriptionFailure.apply(new Error(s"External transcription service failed to transcribe the file ${output.originalFilename}")))

case JsError(errors) =>
Left(JsonParseFailure(errors))
}
}

/**
 * Logs an extraction failure and records it in the manifest.
 *
 * If the manifest itself cannot record the failure, that secondary problem is
 * only logged; nothing is returned or thrown to the caller from that path.
 *
 * @param uri           identifies the item the failure is recorded against
 * @param extractorName name of the extractor that failed
 * @param failureMsg    human-readable description of the failure
 */
private def markAsFailure(uri: Uri, extractorName: String, failureMsg: String): Unit = {
  logger.error(s"Error in '${extractorName} processing ${uri}': ${failureMsg}")

  val recorded = manifest.logExtractionFailure(uri, extractorName, failureMsg)
  recorded.fold(
    manifestFailure => logger.error(s"Failed to log extractor in manifest: ${manifestFailure.msg}"),
    _ => ()
  )
}
}
2 changes: 1 addition & 1 deletion backend/conf/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ ocr {
transcribe {
whisperModelFilename = "ggml-base.bin"
transcriptionServiceQueueUrl = "http://sqs.eu-west-1.localhost.localstack.cloud:4566/000000000000/transcription-service-task-queue-DEV.fifo"
transcriptionOutputQueueUrl = "http://sqs.eu-west-1.localhost.localstack.cloud:4566/000000000000/giant-output-queue-DEV"
transcriptionOutputQueueUrl = "http://sqs.eu-west-1.localhost.localstack.cloud:4566/000000000000/giant-output-queue-DEV.fifo"
}

sqs {
Expand Down

0 comments on commit 6696bbe

Please sign in to comment.