diff --git a/src/fstalign.cpp b/src/fstalign.cpp index 55262ab..0ffbcb6 100644 --- a/src/fstalign.cpp +++ b/src/fstalign.cpp @@ -678,9 +678,10 @@ void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine RecordSpeakerSwitchWer(stitches, alignerOptions.speaker_switch_context_size); } - // Calculate and record per-speaker WER + // Calculate and record supplementary WER RecordSpeakerWer(stitches); RecordTagWer(stitches); + RecordSentenceWer(stitches); if (!output_nlp.empty()) { ofstream nlp_ostream(output_nlp); diff --git a/src/wer.cpp b/src/wer.cpp index 8198790..7669701 100644 --- a/src/wer.cpp +++ b/src/wer.cpp @@ -74,6 +74,34 @@ void RecordWer(spWERA topAlignment) { } } +void RecordSentenceWer(vector> stitches) { + std::set eos_punc{".", "?", "!"}; + vector sentence_wers; + WerResult curr_wer = {0, 0, 0, 0, 0}; + for (auto &stitch : stitches) { + curr_wer.deletions += stitch->comment.rfind("del", 0) == 0; + curr_wer.insertions += stitch->comment.rfind("ins", 0) == 0; + curr_wer.substitutions += stitch->comment.rfind("sub", 0) == 0; + curr_wer.numWordsInReference += stitch->comment.rfind("ins", 0) != 0; + + // Check if we hit EOS + if (eos_punc.find(stitch->nlpRow.punctuation) != eos_punc.end()) { + sentence_wers.push_back(curr_wer); + curr_wer = {0, 0, 0, 0, 0}; + } + } + // Add last one if its empty case + if (curr_wer.numWordsInReference > 0) { + sentence_wers.push_back(curr_wer); + } + + // Add to log + for (int i=0; i < sentence_wers.size(); i++) { + RecordWerResult(jsonLogger::JsonLogger::getLogger().root["wer"]["sentenceWer"][i], sentence_wers[i]); + } +} + + void RecordSpeakerWer(vector> stitches) { // Note: stitches must have already been aligned to NLP rows // Logic for segment boundaries copied from speaker switch WER code diff --git a/src/wer.h b/src/wer.h index bf42d6d..e75b192 100644 --- a/src/wer.h +++ b/src/wer.h @@ -39,6 +39,7 @@ void RecordWerResult(Json::Value &json, WerResult wr); void RecordWer(spWERA topAlignment); void RecordSpeakerWer(vector> stitches); void RecordSpeakerSwitchWer(vector> stitches, int speaker_switch_context_size); +void RecordSentenceWer(vector> stitches); void RecordTagWer(vector> stitches); void RecordCaseWer(vector> aligned_stitches);