diff --git a/Dockerfile b/Dockerfile
index 75de78b..347925f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,6 @@
 # Using kaldi image for pre-built OpenFST, version is 1.7.2
-FROM kaldiasr/kaldi:latest as kaldi-base
+FROM kaldiasr/kaldi:cpu-debian10-2024-07-29 as kaldi-base
+
 FROM debian:11
 
 COPY --from=kaldi-base /opt/kaldi/tools/openfst /opt/openfst
diff --git a/src/fstalign.cpp b/src/fstalign.cpp
index 37ee686..cd984bf 100644
--- a/src/fstalign.cpp
+++ b/src/fstalign.cpp
@@ -242,15 +242,15 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
     stitches.emplace_back();
     Stitching &part = stitches.back();
     part.classLabel = tk_classLabel;
-    part.reftk = ref_tk;
-    part.hyptk = hyp_tk;
+    part.reftk = {ref_tk};
+    part.hyptk = {hyp_tk};
     bool del = false, ins = false, sub = false;
     if (ref_tk == INS) {
       part.comment = "ins";
     } else if (hyp_tk == DEL) {
       part.comment = "del";
    } else if (hyp_tk != ref_tk) {
-      part.comment = "sub(" + part.hyptk + ")";
+      part.comment = "sub(" + part.hyptk.token + ")";
     }
 
     // for classes, we will have only one token in the global vector
@@ -281,10 +281,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
 
     if (!hyp_ctm_rows.empty()) {
       auto ctmPart = hyp_ctm_rows[hypRowIndex];
-      part.start_ts = ctmPart.start_time_secs;
-      part.duration = ctmPart.duration_secs;
-      part.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
-      part.confidence = ctmPart.confidence;
+      part.hyptk.start_ts = ctmPart.start_time_secs;
+      part.hyptk.duration = ctmPart.duration_secs;
+      part.hyptk.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
+      part.hyptk.confidence = ctmPart.confidence;
       part.hyp_orig = ctmPart.word;
 
       // sanity check
@@ -308,21 +308,24 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
         float ts = stof(hypNlpPart.ts);
         float endTs = stof(hypNlpPart.endTs);
-        part.start_ts = ts;
-        part.end_ts = endTs;
-        part.duration = endTs - ts;
+        part.hyptk.start_ts = ts;
+        part.hyptk.end_ts = endTs;
+        part.hyptk.duration = endTs - ts;
       } else if (!hypNlpPart.ts.empty()) {
         float ts = stof(hypNlpPart.ts);
-        part.start_ts = ts;
-        part.end_ts = ts;
-        part.duration = 0.0;
+        part.hyptk.start_ts = ts;
+        part.hyptk.end_ts = ts;
+        part.hyptk.duration = 0.0;
      } else if (!hypNlpPart.endTs.empty()) {
         float endTs = stof(hypNlpPart.endTs);
-        part.start_ts = endTs;
-        part.end_ts = endTs;
-        part.duration = 0.0;
+        part.hyptk.start_ts = endTs;
+        part.hyptk.end_ts = endTs;
+        part.hyptk.duration = 0.0;
+      }
+      if (!hypNlpPart.confidence.empty()) {
+        part.hyptk.confidence = stof(hypNlpPart.confidence);
       }
     }
 
@@ -575,15 +578,15 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
     // if the comment starts with 'ins'
     if (stitch.comment.find("ins") == 0 && !add_inserts) {
       // there's no nlp row info for such case, let's skip over it
-      if (stitch.confidence >= 1) {
-        logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts);
+      if (stitch.hyptk.confidence >= 1) {
+        logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk.token, stitch.hyptk.start_ts);
       }
 
       continue;
     }
 
     string original_nlp_token = stitch.nlpRow.token;
-    string ref_tk = stitch.reftk;
+    string ref_tk = stitch.reftk.token;
 
     // trying to salvage some of the original punctuation in a relatively safe manner
     if (iequals(ref_tk, original_nlp_token)) {
@@ -597,9 +600,9 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
       ref_tk = original_nlp_token;
     } else if (stitch.comment.find("ins") == 0) {
       assert(add_inserts);
-      logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment);
+      logger->debug("an insertion was found for {} {}", stitch.hyptk.token, stitch.comment);
       ref_tk = "";
-      stitch.comment = "ins(" + stitch.hyptk + ")";
+      stitch.comment = "ins(" + stitch.hyptk.token + ")";
     }
 
     if (ref_tk == NOOP) {
@@ -607,11 +610,11 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
     }
 
     output_nlp_file << ref_tk << "|" << stitch.nlpRow.speakerId << "|";
-    if (stitch.hyptk == DEL) {
+    if (stitch.hyptk.token == DEL) {
       // we have no ts/endTs data to put...
       output_nlp_file << "||";
     } else {
-      output_nlp_file << fmt::format("{0:.4f}", stitch.start_ts) << "|" << fmt::format("{0:.4f}", stitch.end_ts)
+      output_nlp_file << fmt::format("{0:.4f}", stitch.hyptk.start_ts) << "|" << fmt::format("{0:.4f}", stitch.hyptk.end_ts)
                       << "|";
     }
 
@@ -632,7 +635,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
 }
 
 void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
-               AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case) {
+               AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<std::string> ref_extra_columns, std::vector<std::string> hyp_extra_columns) {
   // int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
   // string composition_approach, bool record_case_stats) {
   auto logger = logger::GetOrCreateLogger("fstalign");
@@ -698,7 +701,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
   JsonLogUnigramBigramStats(topAlignment);
   if (!output_sbs.empty()) {
     logger->info("output_sbs = {}", output_sbs);
-    WriteSbs(topAlignment, stitches, output_sbs);
+    WriteSbs(topAlignment, stitches, output_sbs, ref_extra_columns, hyp_extra_columns);
   }
 
   if (!output_nlp.empty() && !nlp_ref_loader) {
@@ -720,3 +723,15 @@ void HandleAlign(NlpFstLoader& refLoader, CtmFstLoader& hypLoader, SynonymEngine
   align_stitches_to_nlp(refLoader, stitches);
   write_stitches_to_nlp(stitches, output_nlp_file, refLoader.mJsonNorm);
 }
+
+string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property) {
+  std::unordered_map<string, std::function<string(Token)>> col_name_to_val = {
+      {"speaker", [](Token tk) {return tk.speaker;}},
+      {"ts", [](Token tk) {return to_string(tk.start_ts);}},
+      {"endTs", [](Token tk) {return to_string(tk.end_ts);}},
+      {"confidence", [](Token tk) {return to_string(tk.confidence);}},
+  };
+  if (refToken) return col_name_to_val[property](stitch.reftk);
+  if (!refToken) return col_name_to_val[property](stitch.hyptk);
+  return "";
+}
diff --git a/src/fstalign.h b/src/fstalign.h
index 0320785..8af14f6 100644
--- a/src/fstalign.h
+++ b/src/fstalign.h
@@ -15,15 +15,21 @@ fstalign.h
 using namespace std;
 using namespace fst;
 
+// Represents information associated with a reference or hypothesis token
+struct Token {
+  string token;
+  float start_ts=0.0;
+  float end_ts=0.0;
+  float duration=0.0;
+  float confidence=-1.0;
+  string speaker;
+};
+
 // Stitchings will be used to represent fstalign output, combining reference,
 // hypothesis, and error information into a record-like data structure.
 struct Stitching {
-  string reftk;
-  string hyptk;
-  float start_ts;
-  float end_ts;
-  float duration;
-  float confidence;
+  Token reftk;
+  Token hyptk;
   string classLabel;
   RawNlpRecord nlpRow;
   string hyp_orig;
@@ -42,17 +48,12 @@ struct AlignerOptions {
   int levenstein_maximum_error_streak = 100;
 };
 
-// original
-// void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine, string output_sbs, string
-// output_nlp,
-//                int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
-//                string composition_approach, bool record_case_stats);
-// void HandleAlign(NlpFstLoader *refLoader, CtmFstLoader *hypLoader, SynonymEngine *engine, ofstream &output_nlp_file,
-//                  int numBests, string symbols_filename, string composition_approach);
 void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
-               AlignerOptions alignerOptions, bool add_inserts_nlp = false, bool use_case = false);
+               AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<std::string> ref_extra_columns, std::vector<std::string> hyp_extra_columns);
 void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
                  AlignerOptions alignerOptions);
 
+string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property);
+
 #endif // __FSTALIGN_H__
diff --git a/src/main.cpp b/src/main.cpp
index d87ddbc..bf03942 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -40,6 +40,9 @@ int main(int argc, char **argv) {
   bool disable_cutoffs = false;
   bool disable_hyphen_ignore = false;
 
+  std::vector<std::string> ref_extra_columns = std::vector<std::string>();
+  std::vector<std::string> hyp_extra_columns = std::vector<std::string>();
+
   CLI::App app("Rev FST Align");
   app.set_help_all_flag("--help-all", "Expand all help");
   app.add_flag("--version", version, "Show fstalign version.");
@@ -97,6 +100,10 @@ int main(int argc, char **argv) {
     c->add_option("--composition-approach", composition_approach,
                   "Desired composition logic. Choices are 'standard' or 'adapted'");
+    c->add_option("--ref-extra-cols", ref_extra_columns,
+                  "Extra columns from the reference to include in SBS output.");
+    c->add_option("--hyp-extra-cols", hyp_extra_columns,
+                  "Extra columns from the hypothesis to include in SBS output.");
   }
 
   get_wer->add_option("--wer-sidecar", wer_sidecar_filename, "WER sidecar json file.");
@@ -180,7 +187,7 @@ int main(int argc, char **argv) {
   }
 
   if (command == "wer") {
-    HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case);
+    HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case, ref_extra_columns, hyp_extra_columns);
   } else if (command == "align") {
     if (output_nlp.empty()) {
       console->error("the output nlp file must be specified");
diff --git a/src/version.h b/src/version.h
index 01137a8..5489d6d 100644
--- a/src/version.h
+++ b/src/version.h
@@ -1,5 +1,5 @@
 #pragma once
 
 #define FSTALIGNER_VERSION_MAJOR 1
-#define FSTALIGNER_VERSION_MINOR 13
+#define FSTALIGNER_VERSION_MINOR 14
 #define FSTALIGNER_VERSION_PATCH 0
diff --git a/src/wer.cpp b/src/wer.cpp
index 1b2f066..f326c5d 100644
--- a/src/wer.cpp
+++ b/src/wer.cpp
@@ -262,8 +262,8 @@ void RecordCaseWer(const vector<Stitching> &aligned_stitches) {
   for (const auto &stitch : aligned_stitches) {
     const string &hyp = stitch.hyp_orig;
     const string &ref = stitch.nlpRow.token;
-    const string &reftk = stitch.reftk;
-    const string &hyptk = stitch.hyptk;
+    const string &reftk = stitch.reftk.token;
+    const string &hyptk = stitch.hyptk.token;
     const string &ref_casing = stitch.nlpRow.casing;
 
     if (hyptk == DEL || reftk == INS) {
@@ -526,7 +526,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
   hyp = "";
 }
 
-void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename) {
+void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns) {
   auto logger = logger::GetOrCreateLogger("wer");
   logger->set_level(spdlog::level::info);
 
@@ -536,7 +536,14 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
   AlignmentTraversor visitor(topAlignment);
   string prev_tk_classLabel = "";
   logger->info("Side-by-Side alignment info going into {}", sbs_filename);
-  myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities") << endl;
+  myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities");
+  for (string col_name: extra_ref_columns) {
+    myfile << fmt::format("\tref_{0}", col_name);
+  }
+  for (string col_name: extra_hyp_columns) {
+    myfile << fmt::format("\thyp_{0}", col_name);
+  }
+  myfile << endl;
 
   // keep track of error groupings
   ErrorGroups groups_err;
@@ -554,8 +561,8 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
     for (auto wer_tag: wer_tags) {
       tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
     }
-    string ref_tk = p_stitch.reftk;
-    string hyp_tk = p_stitch.hyptk;
+    string ref_tk = p_stitch.reftk.token;
+    string hyp_tk = p_stitch.hyptk.token;
 
     string tag = "";
     if (ref_tk == NOOP) {
@@ -587,7 +594,15 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
       eff_class = tk_classLabel;
     }
 
-    myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags) << endl;
+    myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags);
+
+    for (string col_name: extra_ref_columns) {
+      myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, true, col_name));
+    }
+    for (string col_name: extra_hyp_columns) {
+      myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, false, col_name));
+    }
+    myfile << endl;
 
     offset++;
   }
diff --git a/src/wer.h b/src/wer.h
index f0e9f35..b66e978 100644
--- a/src/wer.h
+++ b/src/wer.h
@@ -49,5 +49,5 @@ void CalculatePrecisionRecall(wer_alignment &topAlignment, int threshold);
 
 typedef vector<pair<string, string>> ErrorGroups;
 void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
-void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename);
+void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns);
 void JsonLogUnigramBigramStats(wer_alignment &topAlignment);
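
Below is a minimal standalone sketch, not part of the patch above, of how the new --ref-extra-cols / --hyp-extra-cols values resolve per token: a column name is looked up in a map of accessors over the new Token struct, mirroring GetTokenPropertyAsString from src/fstalign.cpp. The helper name, the main() driver, and the sample token values here are illustrative assumptions only.

// sketch.cpp -- illustrative only; mirrors the column-name -> Token-field mapping
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using namespace std;

struct Token {  // same fields as the Token struct introduced in src/fstalign.h
  string token;
  float start_ts = 0.0;
  float end_ts = 0.0;
  float duration = 0.0;
  float confidence = -1.0;
  string speaker;
};

// Hypothetical helper: returns the requested property of a token as a string,
// or "" for an unknown column name (the patch itself indexes its map directly).
string TokenPropertyAsString(const Token &tk, const string &property) {
  static const unordered_map<string, function<string(const Token &)>> col_name_to_val = {
      {"speaker", [](const Token &t) { return t.speaker; }},
      {"ts", [](const Token &t) { return to_string(t.start_ts); }},
      {"endTs", [](const Token &t) { return to_string(t.end_ts); }},
      {"confidence", [](const Token &t) { return to_string(t.confidence); }},
  };
  auto it = col_name_to_val.find(property);
  return it == col_name_to_val.end() ? "" : it->second(tk);
}

int main() {
  Token hyp{"hello", 1.25f, 1.75f, 0.5f, 0.93f, "spk1"};
  // If the hypothesis columns confidence, ts, and speaker were requested,
  // values like these would be appended to an SBS row as extra tab-separated fields:
  cout << TokenPropertyAsString(hyp, "confidence") << "\t"   // 0.930000
       << TokenPropertyAsString(hyp, "ts") << "\t"           // 1.250000
       << TokenPropertyAsString(hyp, "speaker") << endl;     // spk1
}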