diff --git a/docs/Advanced-Usage.md b/docs/Advanced-Usage.md
index 92498f0..e4a5dd7 100644
--- a/docs/Advanced-Usage.md
+++ b/docs/Advanced-Usage.md
@@ -72,6 +72,13 @@ Normalizations are a similar concept to synonyms. They allow a token or group of
 }
 ```
 
+### WER Sidecar
+
+CLI flag: `--wer-sidecar`
+
+Only usable for NLP format reference files. This passes a [WER sidecar](https://github.com/revdotcom/fstalign/blob/develop/docs/NLP-Format.md#wer-tag-sidecar) file to
+add extra information to some outputs. Optional.
+
 ## Outputs
 
 ### Text Log
diff --git a/docs/NLP-Format.md b/docs/NLP-Format.md
index e2593ff..40df5d8 100644
--- a/docs/NLP-Format.md
+++ b/docs/NLP-Format.md
@@ -25,4 +25,19 @@ first|0||||LC|['6:DATE']|['6']
 quarter|0||||LC|['6:DATE']|['6']
 2020|0||||CA|['0:YEAR']|['0', '1', '6']
 NexGEn|0||||MC|['7:ORG']|['7']
-```
\ No newline at end of file
+```
+
+## WER tag sidecar
+
+WER tag sidecar files contain accompanying info for tokens in an NLP file. The
+keys are IDs corresponding to tokens in the NLP file `wer_tags` column. The
+objects under the keys are information about the token.
+
+Example:
+```
+{
+  '0': {'entity_type': 'YEAR'},
+  '1': {'entity_type': 'CARDINAL'},
+  '6': {'entity_type': 'SPACY>TIME'},
+}
+```
diff --git a/src/Nlp.cpp b/src/Nlp.cpp
index c6fc40e..b11ba18 100644
--- a/src/Nlp.cpp
+++ b/src/Nlp.cpp
@@ -15,21 +15,36 @@
 /***********************************
  NLP FstLoader class start
  ************************************/
 
-NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization)
-    : NlpFstLoader(records, normalization, true) {}
+NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
+                           Json::Value wer_sidecar)
+    : NlpFstLoader(records, normalization, wer_sidecar, true) {}
 
-NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, bool processLabels)
+NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
+                           Json::Value wer_sidecar, bool processLabels)
     : FstLoader() {
   mNlpRows = records;
   mJsonNorm = normalization;
+  mWerSidecar = wer_sidecar;
   std::string last_label;
   bool firstTk = true;
 
+  // fuse multiple rows that have the same id/label into one entry only
   for (auto &row : mNlpRows) {
     auto curr_tk = row.token;
     auto curr_label = row.best_label;
     auto curr_label_id = row.best_label_id;
+    auto curr_row_tags = row.wer_tags;
+    // Update wer tags in records to real string labels
+    vector<string> real_wer_tags;
+    for (auto &tag: curr_row_tags) {
+      auto real_tag = tag;
+      if (mWerSidecar != Json::nullValue) {
+        real_tag = "###"+ real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
+      }
+      real_wer_tags.push_back(real_tag);
+    }
+    row.wer_tags = real_wer_tags;
     std::string speaker = row.speakerId;
 
     if (processLabels && curr_label != "") {
diff --git a/src/Nlp.h b/src/Nlp.h
index 774c3b6..8c9e65e 100644
--- a/src/Nlp.h
+++ b/src/Nlp.h
@@ -42,8 +42,8 @@ class NlpReader {
 
 class NlpFstLoader : public FstLoader {
  public:
-  NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, bool processLabels);
-  NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization);
+  NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels);
+  NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar);
   virtual ~NlpFstLoader();
   virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
   virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const;
@@ -53,6 +53,7 @@ class NlpFstLoader : public FstLoader {
   vector<RawNlpRecord> mNlpRows;
   vector<string> mSpeakers;
   Json::Value mJsonNorm;
+  Json::Value mWerSidecar;
 
   virtual const std::string &getToken(int index) const { return mToken.at(index); }
 };
diff --git a/src/fstalign.cpp b/src/fstalign.cpp
index 042ef19..55262ab 100644
--- a/src/fstalign.cpp
+++ b/src/fstalign.cpp
@@ -636,34 +636,29 @@ void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine
   CalculatePrecisionRecall(topAlignment, alignerOptions.pr_threshold);
   RecordWer(topAlignment);
 
-  if (!output_sbs.empty()) {
-    logger->info("output_sbs = {}", output_sbs);
-    WriteSbs(topAlignment, output_sbs);
+  vector<shared_ptr<Stitching>> stitches;
+  CtmFstLoader *ctm_hyp_loader = dynamic_cast<CtmFstLoader *>(hypLoader);
+  NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(hypLoader);
+  OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(hypLoader);
+  if (ctm_hyp_loader) {
+    stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
+  } else if (nlp_hyp_loader) {
+    stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
+  } else if (best_loader) {
+    vector<string> tokens;
+    tokens.reserve(best_loader->TokensSize());
+    for (int i = 0; i < best_loader->TokensSize(); i++) {
+      string token = best_loader->getToken(i);
+      tokens.push_back(token);
+    }
+    stitches = make_stitches(topAlignment, {}, {}, tokens);
+  } else {
+    stitches = make_stitches(topAlignment);
   }
 
   NlpFstLoader *nlp_ref_loader = dynamic_cast<NlpFstLoader *>(refLoader);
   if (nlp_ref_loader) {
     // We have an NLP reference, more metadata (e.g. speaker info) is available
-    vector<shared_ptr<Stitching>> stitches;
-    CtmFstLoader *ctm_hyp_loader = dynamic_cast<CtmFstLoader *>(hypLoader);
-    NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(hypLoader);
-    OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(hypLoader);
-    if (ctm_hyp_loader) {
-      stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
-    } else if (nlp_hyp_loader) {
-      stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
-    } else if (best_loader) {
-      vector<string> tokens;
-      tokens.reserve(best_loader->TokensSize());
-      for (int i = 0; i < best_loader->TokensSize(); i++) {
-        string token = best_loader->getToken(i);
-        tokens.push_back(token);
-      }
-      stitches = make_stitches(topAlignment, {}, {}, tokens);
-    } else {
-      stitches = make_stitches(topAlignment);
-    }
-
     // Align stitches to the NLP, so stitches can access metadata
     try {
       align_stitches_to_nlp(nlp_ref_loader, &stitches);
@@ -693,6 +688,11 @@
     }
   }
 
+  if (!output_sbs.empty()) {
+    logger->info("output_sbs = {}", output_sbs);
+    WriteSbs(topAlignment, stitches, output_sbs);
+  }
+
   if (!output_nlp.empty() && !nlp_ref_loader) {
     logger->warn("Attempted to output an Aligned NLP file without NLP reference, skipping output.");
   }
diff --git a/src/main.cpp b/src/main.cpp
index f6186f8..bad45eb 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -16,6 +16,7 @@ int main(int argc, char **argv) {
   setlocale(LC_ALL, "en_US.UTF-8");
   string ref_filename;
   string json_norm_filename;
+  string wer_sidecar_filename;
   string hyp_filename;
   string log_filename = "";
   string output_nlp = "";
@@ -94,6 +95,8 @@
     c->add_option("--composition-approach", composition_approach,
                   "Desired composition logic. Choices are 'standard' or 'adapted'");
   }
 
+  get_wer->add_option("--wer-sidecar", wer_sidecar_filename,
+                      "WER sidecar json file.");
   get_wer->add_option("--speaker-switch-context", speaker_switch_context_size,
                       "Amount of context (in each direction) around "
@@ -166,6 +169,27 @@
     Json::parseFromStream(builder, ss, &obj, &errs);
   }
 
+  Json::Value wer_sidecar_obj;
+  if (!wer_sidecar_filename.empty()) {
+    console->info("reading wer sidecar info from {}", wer_sidecar_filename);
+    ifstream ifs(wer_sidecar_filename);
+
+    Json::CharReaderBuilder builder;
+    builder["collectComments"] = false;
+
+    JSONCPP_STRING errs;
+    Json::parseFromStream(builder, ifs, &wer_sidecar_obj, &errs);
+
+    console->info("The json we just read [{}] has {} elements from its root", wer_sidecar_filename, wer_sidecar_obj.size());
+  } else {
+    stringstream ss;
+    ss << "{}";
+
+    Json::CharReaderBuilder builder;
+    JSONCPP_STRING errs;
+    Json::parseFromStream(builder, ss, &wer_sidecar_obj, &errs);
+  }
+
   Json::Value hyp_json_obj;
   if (!hyp_json_norm_filename.empty()) {
     console->info("reading hypothesis json norm info from {}", hyp_json_norm_filename);
@@ -194,7 +218,7 @@
       NlpReader nlpReader = NlpReader();
       console->info("reading reference nlp from {}", ref_filename);
       auto vec = nlpReader.read_from_disk(ref_filename);
-      NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, true);
+      NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, wer_sidecar_obj, true);
       ref = nlpFst;
     } else if (EndsWithCaseInsensitive(ref_filename, string(".ctm"))) {
       console->info("reading reference ctm from {}", ref_filename);
@@ -212,11 +236,19 @@
   // loading "hypothesis" inputs
   if (EndsWithCaseInsensitive(hyp_filename, string(".nlp"))) {
     console->info("reading hypothesis nlp from {}", hyp_filename);
+    // Make empty json for wer sidecar
+    Json::Value hyp_empty_json;
+    stringstream ss;
+    ss << "{}";
+
+    Json::CharReaderBuilder builder;
+    JSONCPP_STRING errs;
+    Json::parseFromStream(builder, ss, &hyp_empty_json, &errs);
     NlpReader nlpReader = NlpReader();
     auto vec = nlpReader.read_from_disk(hyp_filename);
     // for now, nlp files passed as hypothesis won't have their labels handled as such
     // this also mean that json normalization will be ignored
-    NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, false);
+    NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, hyp_empty_json, false);
     hyp = nlpFst;
   } else if (EndsWithCaseInsensitive(hyp_filename, string(".ctm"))) {
     console->info("reading hypothesis ctm from {}", hyp_filename);
diff --git a/src/wer.cpp b/src/wer.cpp
index b4253a1..8198790 100644
--- a/src/wer.cpp
+++ b/src/wer.cpp
@@ -327,16 +327,19 @@ void RecordTagWer(vector<shared_ptr<Stitching>> stitches) {
   for (auto &stitch : stitches) {
     if (!stitch->nlpRow.wer_tags.empty()) {
       for (auto wer_tag : stitch->nlpRow.wer_tags) {
-        wer_results.insert(std::pair<string, WerResult>(wer_tag, {0, 0, 0, 0, 0}));
+        int tag_start = wer_tag.find_first_not_of('#');
+        int tag_end = wer_tag.find('_');
+        string wer_tag_id = wer_tag.substr(tag_start, tag_end - tag_start);
+        wer_results.insert(std::pair<string, WerResult>(wer_tag_id, {0, 0, 0, 0, 0}));
         // Check with rfind since other comments can be there
         bool del = stitch->comment.rfind("del", 0) == 0;
         bool ins = stitch->comment.rfind("ins", 0) == 0;
         bool sub = stitch->comment.rfind("sub", 0) == 0;
-        wer_results[wer_tag].insertions += ins;
-        wer_results[wer_tag].deletions += del;
-        wer_results[wer_tag].substitutions += sub;
+        wer_results[wer_tag_id].insertions += ins;
+        wer_results[wer_tag_id].deletions += del;
+        wer_results[wer_tag_id].substitutions += sub;
         if (!ins) {
-          wer_results[wer_tag].numWordsInReference += 1;
+          wer_results[wer_tag_id].numWordsInReference += 1;
         }
       }
     }
@@ -503,7 +506,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
   hyp = "";
 }
 
-void WriteSbs(spWERA topAlignment, string sbs_filename) {
+void WriteSbs(spWERA topAlignment, vector<shared_ptr<Stitching>> stitches, string sbs_filename) {
   auto logger = logger::GetOrCreateLogger("wer");
   logger->set_level(spdlog::level::info);
 
@@ -514,7 +517,7 @@
   triple *tk_pair = new triple();
   string prev_tk_classLabel = "";
   logger->info("Side-by-Side alignment info going into {}", sbs_filename);
-  myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}", "ref_token", "hyp_token", "IsErr", "Class") << endl;
+  myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities") << endl;
 
   // keep track of error groupings
   ErrorGroups groups_err;
@@ -525,10 +528,15 @@
   std::set<string> op_set = {"<ins>", "<del>", "<sub>"};
   size_t offset = 2;  // line number in output file where first triple starts
 
-  while (visitor.NextTriple(tk_pair)) {
-    string tk_classLabel = tk_pair->classLabel;
-    string ref_tk = tk_pair->ref;
-    string hyp_tk = tk_pair->hyp;
+  for (auto p_stitch: stitches) {
+    string tk_classLabel = p_stitch->classLabel;
+    string tk_wer_tags = "";
+    auto wer_tags = p_stitch->nlpRow.wer_tags;
+    for (auto wer_tag: wer_tags) {
+      tk_wer_tags = tk_wer_tags + wer_tag + "|";
+    }
+    string ref_tk = p_stitch->reftk;
+    string hyp_tk = p_stitch->hyptk;
 
     string tag = "";
     if (ref_tk == NOOP) {
@@ -560,7 +568,7 @@
       eff_class = tk_classLabel;
     }
 
-    myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}", ref_tk, hyp_tk, tag, eff_class) << endl;
+    myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags) << endl;
     offset++;
   }
 
diff --git a/src/wer.h b/src/wer.h
index 78a1451..bf42d6d 100644
--- a/src/wer.h
+++ b/src/wer.h
@@ -48,4 +48,4 @@ void CalculatePrecisionRecall(spWERA &topAlignment, int threshold);
 
 typedef vector<pair<size_t, string>> ErrorGroups;
 void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
-void WriteSbs(spWERA topAlignment, string sbs_filename);
+void WriteSbs(spWERA topAlignment, vector<shared_ptr<Stitching>> stitches, string sbs_filename);
diff --git a/test/data/syn_1.hyp.sbs b/test/data/syn_1.hyp.sbs
index 3a05e83..3e71f78 100644
--- a/test/data/syn_1.hyp.sbs
+++ b/test/data/syn_1.hyp.sbs
@@ -1,28 +1,28 @@
-           ref_token	hyp_token           	IsErr	Class
-                  we	                    	ERR	
-                will	we'll               	ERR	
-                have	have                		
-                   a	a                   		
-                nice	nice                		
-             evening	evening             		
-                  um	                    	ERR	
-                  no	no                  		
-              matter	matter              		
-                what	what                		
-                will	will                		
-              happen	happen              		
-                  it	                    	ERR	
-                  um	is                  	ERR	
-                it's	uh                  	ERR	
-                   a	a                   		
-                good	good                		
-         opportunity	opportunity         		
-                  to	to                  		
-                  do	                    	ERR	
-                this	this                		
-              you'll	you'll              		
-                  uh	                    	ERR	
-                 see	see                 		
+           ref_token	hyp_token           	IsErr	Class	Wer_Tag_Entities
+                  we	                    	ERR		
+                will	we'll               	ERR		
+                have	have                			
+                   a	a                   			
+                nice	nice                			
+             evening	evening             			
+                  um	                    	ERR		
+                  no	no                  			
+              matter	matter              			
+                what	what                			
+                will	will                			
+              happen	happen              			
+                  it	                    	ERR		
+                  um	is                  	ERR		
+                it's	uh                  	ERR		
+                   a	a                   			
+                good	good                			
+         opportunity	opportunity         			
+                  to	to                  			
+                  do	                    	ERR		
+                this	this                			
+              you'll	you'll              			
+                  uh	                    	ERR		
+                 see	see                 			
 ------------------------------------------------------------
 Line Group 2
 we will <-> we'll
diff --git a/test/data/twenty.hyp-a2.sbs b/test/data/twenty.hyp-a2.sbs
index d9caeee..7d49c95 100644
--- a/test/data/twenty.hyp-a2.sbs
+++ b/test/data/twenty.hyp-a2.sbs
@@ -1,13 +1,13 @@
-           ref_token	hyp_token           	IsErr	Class
-                  20	                    	ERR	___1_CARDINAL___
-                  in	in                  		
-              twenty	twenty              		___2_YEAR___
-              twenty	thirty              	ERR	___2_YEAR___
-                  is	is                  		
-                 one	one                 		___3_CARDINAL___
-              twenty	twenty              		___3_CARDINAL___
-                 two	                    	ERR	___3_CARDINAL___
-               three	three               		___3_CARDINAL___
+           ref_token	hyp_token           	IsErr	Class	Wer_Tag_Entities
+                  20	                    	ERR	___1_CARDINAL___	
+                  in	in                  			
+              twenty	twenty              		___2_YEAR___	
+              twenty	thirty              	ERR	___2_YEAR___	
+                  is	is                  			
+                 one	one                 		___3_CARDINAL___	
+              twenty	twenty              		___3_CARDINAL___	
+                 two	                    	ERR	___3_CARDINAL___	
+               three	three               		___3_CARDINAL___	
 ------------------------------------------------------------
 Line Group 2
 20 <-> ***
diff --git a/test/data/twenty.hyp.sbs b/test/data/twenty.hyp.sbs
index 8b8e401..e75a474 100644
--- a/test/data/twenty.hyp.sbs
+++ b/test/data/twenty.hyp.sbs
@@ -1,13 +1,13 @@
-           ref_token	hyp_token           	IsErr	Class
-              twenty	                    	ERR	___1_CARDINAL___
-                  in	in                  		
-              twenty	twenty              		___2_YEAR___
-              twenty	thirty              	ERR	___2_YEAR___
-                  is	is                  		
-                 one	one                 		___3_CARDINAL___
-              twenty	twenty              		___3_CARDINAL___
-                 two	                    	ERR	___3_CARDINAL___
-               three	three               		___3_CARDINAL___
+           ref_token	hyp_token           	IsErr	Class	Wer_Tag_Entities
+              twenty	                    	ERR	___1_CARDINAL___	
+                  in	in                  			
+              twenty	twenty              		___2_YEAR___	
+              twenty	thirty              	ERR	___2_YEAR___	
+                  is	is                  			
+                 one	one                 		___3_CARDINAL___	
+              twenty	twenty              		___3_CARDINAL___	
+                 two	                    	ERR	___3_CARDINAL___	
+               three	three               		___3_CARDINAL___	
 ------------------------------------------------------------
 Line Group 2
 twenty <-> ***
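
Note for reviewers (not part of the patch): the sketch below is a self-contained illustration of the two pieces this change wires together in main.cpp and Nlp.cpp, namely parsing a `--wer-sidecar` JSON file with jsoncpp and decorating a raw `wer_tags` id into the `###<id>_<entity_type>###` form used downstream. The `DecorateTag` helper, the inline JSON, and the `main()` harness are illustrative only.

```
// Illustrative sketch only: parse a WER sidecar JSON with jsoncpp and decorate a
// raw wer_tag id the way NlpFstLoader's constructor now does in this patch.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <json/json.h>

// "1" -> "###1_CARDINAL###" when the sidecar knows the tag, unchanged otherwise.
std::string DecorateTag(const std::string &tag, const Json::Value &wer_sidecar) {
  if (wer_sidecar == Json::nullValue) {
    return tag;
  }
  return "###" + tag + "_" + wer_sidecar[tag]["entity_type"].asString() + "###";
}

int main() {
  // Inline stand-in for the contents of a --wer-sidecar file.
  std::stringstream ss;
  ss << R"({"1": {"entity_type": "CARDINAL"}, "2": {"entity_type": "YEAR"}})";

  Json::Value wer_sidecar;
  Json::CharReaderBuilder builder;
  builder["collectComments"] = false;
  JSONCPP_STRING errs;
  if (!Json::parseFromStream(builder, ss, &wer_sidecar, &errs)) {
    std::cerr << "failed to parse sidecar: " << errs << std::endl;
    return 1;
  }

  std::vector<std::string> wer_tags = {"1", "2"};
  for (const auto &tag : wer_tags) {
    std::cout << tag << " -> " << DecorateTag(tag, wer_sidecar) << std::endl;
  }
  return 0;
}
```

Building the sketch requires jsoncpp (for example `g++ sketch.cpp -ljsoncpp`); the exact include path and link flags depend on how jsoncpp is installed.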
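Likewise, a minimal sketch of the other half of the round trip: recovering the numeric tag id from a decorated tag and tallying per-tag errors, mirroring the `find_first_not_of('#')` / `find('_')` / `substr` logic that RecordTagWer now applies. The `Counts` struct and the hard-coded (tag, operation) pairs are stand-ins for fstalign's own WER result structures and stitch comments.

```
// Illustrative sketch only: strip the "###<id>_<entity>###" decoration back down
// to the id and accumulate per-tag error counts, as RecordTagWer does after this patch.
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Counts {
  int insertions = 0;
  int deletions = 0;
  int substitutions = 0;
  int numWordsInReference = 0;
};

// "###1_CARDINAL###" -> "1"
std::string TagId(const std::string &wer_tag) {
  std::size_t tag_start = wer_tag.find_first_not_of('#');
  std::size_t tag_end = wer_tag.find('_');
  return wer_tag.substr(tag_start, tag_end - tag_start);
}

int main() {
  std::map<std::string, Counts> wer_results;

  // Pretend these (decorated tag, alignment op) pairs came from stitches;
  // an empty op means the token was correct.
  std::vector<std::pair<std::string, std::string>> observed = {
      {"###1_CARDINAL###", "del"}, {"###2_YEAR###", "sub"}, {"###2_YEAR###", ""}};

  for (const auto &obs : observed) {
    const std::string id = TagId(obs.first);
    bool ins = obs.second == "ins";
    wer_results[id].insertions += ins;
    wer_results[id].deletions += obs.second == "del";
    wer_results[id].substitutions += obs.second == "sub";
    if (!ins) {
      wer_results[id].numWordsInReference += 1;
    }
  }

  for (const auto &kv : wer_results) {
    std::cout << kv.first << ": " << kv.second.deletions << " del, "
              << kv.second.substitutions << " sub, "
              << kv.second.numWordsInReference << " ref words" << std::endl;
  }
  return 0;
}
```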