Allow flexible forwarding of NLP columns to SBS #56

Merged · 8 commits · Oct 31, 2024
3 changes: 2 additions & 1 deletion Dockerfile
@@ -1,5 +1,6 @@
# Using kaldi image for pre-built OpenFST, version is 1.7.2
FROM kaldiasr/kaldi:latest as kaldi-base
FROM kaldiasr/kaldi:cpu-debian10-2024-07-29 as kaldi-base

FROM debian:11

COPY --from=kaldi-base /opt/kaldi/tools/openfst /opt/openfst
65 changes: 40 additions & 25 deletions src/fstalign.cpp
@@ -242,15 +242,15 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
stitches.emplace_back();
Stitching &part = stitches.back();
part.classLabel = tk_classLabel;
part.reftk = ref_tk;
part.hyptk = hyp_tk;
part.reftk = {ref_tk};
part.hyptk = {hyp_tk};
bool del = false, ins = false, sub = false;
if (ref_tk == INS) {
part.comment = "ins";
} else if (hyp_tk == DEL) {
part.comment = "del";
} else if (hyp_tk != ref_tk) {
part.comment = "sub(" + part.hyptk + ")";
part.comment = "sub(" + part.hyptk.token + ")";
}

// for classes, we will have only one token in the global vector
@@ -281,10 +281,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h

if (!hyp_ctm_rows.empty()) {
auto ctmPart = hyp_ctm_rows[hypRowIndex];
part.start_ts = ctmPart.start_time_secs;
part.duration = ctmPart.duration_secs;
part.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
part.confidence = ctmPart.confidence;
part.hyptk.start_ts = ctmPart.start_time_secs;
part.hyptk.duration = ctmPart.duration_secs;
part.hyptk.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
part.hyptk.confidence = ctmPart.confidence;

part.hyp_orig = ctmPart.word;
// sanity check
@@ -308,21 +308,24 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
float ts = stof(hypNlpPart.ts);
float endTs = stof(hypNlpPart.endTs);

part.start_ts = ts;
part.end_ts = endTs;
part.duration = endTs - ts;
part.hyptk.start_ts = ts;
part.hyptk.end_ts = endTs;
part.hyptk.duration = endTs - ts;
} else if (!hypNlpPart.ts.empty()) {
float ts = stof(hypNlpPart.ts);

part.start_ts = ts;
part.end_ts = ts;
part.duration = 0.0;
part.hyptk.start_ts = ts;
part.hyptk.end_ts = ts;
part.hyptk.duration = 0.0;
} else if (!hypNlpPart.endTs.empty()) {
float endTs = stof(hypNlpPart.endTs);

part.start_ts = endTs;
part.end_ts = endTs;
part.duration = 0.0;
part.hyptk.start_ts = endTs;
part.hyptk.end_ts = endTs;
part.hyptk.duration = 0.0;
}
if (!hypNlpPart.confidence.empty()) {
part.hyptk.confidence = stof(hypNlpPart.confidence);
}
}

@@ -575,15 +578,15 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
// if the comment starts with 'ins'
if (stitch.comment.find("ins") == 0 && !add_inserts) {
// there's no nlp row info for such case, let's skip over it
if (stitch.confidence >= 1) {
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts);
if (stitch.hyptk.confidence >= 1) {
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk.token, stitch.hyptk.start_ts);
}

continue;
}

string original_nlp_token = stitch.nlpRow.token;
string ref_tk = stitch.reftk;
string ref_tk = stitch.reftk.token;

// trying to salvage some of the original punctuation in a relatively safe manner
if (iequals(ref_tk, original_nlp_token)) {
@@ -597,21 +600,21 @@
ref_tk = original_nlp_token;
} else if (stitch.comment.find("ins") == 0) {
assert(add_inserts);
logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment);
logger->debug("an insertion was found for {} {}", stitch.hyptk.token, stitch.comment);
ref_tk = "";
stitch.comment = "ins(" + stitch.hyptk + ")";
stitch.comment = "ins(" + stitch.hyptk.token + ")";
}

if (ref_tk == NOOP) {
continue;
}

output_nlp_file << ref_tk << "|" << stitch.nlpRow.speakerId << "|";
if (stitch.hyptk == DEL) {
if (stitch.hyptk.token == DEL) {
// we have no ts/endTs data to put...
output_nlp_file << "||";
} else {
output_nlp_file << fmt::format("{0:.4f}", stitch.start_ts) << "|" << fmt::format("{0:.4f}", stitch.end_ts)
output_nlp_file << fmt::format("{0:.4f}", stitch.hyptk.start_ts) << "|" << fmt::format("{0:.4f}", stitch.hyptk.end_ts)
<< "|";
}

@@ -632,7 +635,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
}

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case) {
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<string> ref_extra_columns, std::vector<string> hyp_extra_columns) {
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
// string composition_approach, bool record_case_stats) {
auto logger = logger::GetOrCreateLogger("fstalign");
@@ -698,7 +701,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
JsonLogUnigramBigramStats(topAlignment);
if (!output_sbs.empty()) {
logger->info("output_sbs = {}", output_sbs);
WriteSbs(topAlignment, stitches, output_sbs);
WriteSbs(topAlignment, stitches, output_sbs, ref_extra_columns, hyp_extra_columns);
}

if (!output_nlp.empty() && !nlp_ref_loader) {
@@ -720,3 +723,15 @@ void HandleAlign(NlpFstLoader& refLoader, CtmFstLoader& hypLoader, SynonymEngine
align_stitches_to_nlp(refLoader, stitches);
write_stitches_to_nlp(stitches, output_nlp_file, refLoader.mJsonNorm);
}

string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property) {
std::unordered_map<std::string, std::function<string(Token)>> col_name_to_val = {
{"speaker", [](Token tk) {return tk.speaker;}},
{"ts", [](Token tk) {return to_string(tk.start_ts);}},
{"endTs", [](Token tk) {return to_string(tk.end_ts);}},
{"confidence", [](Token tk) {return to_string(tk.confidence);}},
};
if (refToken) return col_name_to_val[property](stitch.reftk);
if (!refToken) return col_name_to_val[property](stitch.hyptk);
return "";
}
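
Side note on the helper above (an observation, not part of the PR): std::unordered_map::operator[] default-constructs an empty std::function for any column name outside speaker/ts/endTs/confidence, and invoking that empty function throws std::bad_function_call. A defensive lookup, as a sketch only, could be:

// Sketch only (not in this PR): return an empty string for unsupported
// column names instead of invoking an empty std::function.
auto it = col_name_to_val.find(property);
if (it == col_name_to_val.end()) return "";
return refToken ? it->second(stitch.reftk) : it->second(stitch.hyptk);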
29 changes: 15 additions & 14 deletions src/fstalign.h
@@ -15,15 +15,21 @@ fstalign.h
using namespace std;
using namespace fst;

// Represent information associated with a reference or hypothesis token
struct Token {
string token;
float start_ts=0.0;
float end_ts=0.0;
float duration=0.0;
float confidence=-1.0;
string speaker;
};

// Stitchings will be used to represent fstalign output, combining reference,
// hypothesis, and error information into a record-like data structure.
struct Stitching {
string reftk;
string hyptk;
float start_ts;
float end_ts;
float duration;
float confidence;
Token reftk;
Token hyptk;
string classLabel;
RawNlpRecord nlpRow;
string hyp_orig;
@@ -42,17 +48,12 @@ struct AlignerOptions {
int levenstein_maximum_error_streak = 100;
};

// original
// void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine, string output_sbs, string
// output_nlp,
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
// string composition_approach, bool record_case_stats);
// void HandleAlign(NlpFstLoader *refLoader, CtmFstLoader *hypLoader, SynonymEngine *engine, ofstream &output_nlp_file,
// int numBests, string symbols_filename, string composition_approach);

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp = false, bool use_case = false);
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<string> ref_extra_columns, std::vector<string> hyp_extra_columns);
void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
AlignerOptions alignerOptions);

string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property);

#endif // __FSTALIGN_H__
9 changes: 8 additions & 1 deletion src/main.cpp
@@ -40,6 +40,9 @@ int main(int argc, char **argv) {
bool disable_cutoffs = false;
bool disable_hyphen_ignore = false;

std::vector<string> ref_extra_columns = std::vector<string>();
std::vector<string> hyp_extra_columns = std::vector<string>();

CLI::App app("Rev FST Align");
app.set_help_all_flag("--help-all", "Expand all help");
app.add_flag("--version", version, "Show fstalign version.");
@@ -97,6 +100,10 @@ int main(int argc, char **argv) {

c->add_option("--composition-approach", composition_approach,
"Desired composition logic. Choices are 'standard' or 'adapted'");
c->add_option("--ref-extra-cols", ref_extra_columns,
"Extra columns from the reference to include in SBS output.");
Contributor (review comment): Are these space delimited?

c->add_option("--hyp-extra-cols", hyp_extra_columns,
"Extra columns from the hypothesis to include in SBS output.");
}
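
A hypothetical invocation illustrating the two new options (not taken from the PR; flag names other than --ref-extra-cols and --hyp-extra-cols are recalled from fstalign's existing CLI and may differ). With CLI11 vector options the values are space-delimited, and the flag can also be repeated; the column names map to the fields handled by GetTokenPropertyAsString:

fstalign wer --ref ref.nlp --hyp hyp.ctm --output-sbs out.sbs \
  --ref-extra-cols speaker ts endTs \
  --hyp-extra-cols confidence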
get_wer->add_option("--wer-sidecar", wer_sidecar_filename,
"WER sidecar json file.");
@@ -180,7 +187,7 @@ int main(int argc, char **argv) {
}

if (command == "wer") {
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case);
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case, ref_extra_columns, hyp_extra_columns);
} else if (command == "align") {
if (output_nlp.empty()) {
console->error("the output nlp file must be specified");
2 changes: 1 addition & 1 deletion src/version.h
@@ -1,5 +1,5 @@
#pragma once

#define FSTALIGNER_VERSION_MAJOR 1
#define FSTALIGNER_VERSION_MINOR 13
#define FSTALIGNER_VERSION_MINOR 14
#define FSTALIGNER_VERSION_PATCH 0
29 changes: 22 additions & 7 deletions src/wer.cpp
@@ -262,8 +262,8 @@ void RecordCaseWer(const vector<Stitching> &aligned_stitches) {
for (const auto &stitch : aligned_stitches) {
const string &hyp = stitch.hyp_orig;
const string &ref = stitch.nlpRow.token;
const string &reftk = stitch.reftk;
const string &hyptk = stitch.hyptk;
const string &reftk = stitch.reftk.token;
const string &hyptk = stitch.hyptk.token;
const string &ref_casing = stitch.nlpRow.casing;

if (hyptk == DEL || reftk == INS) {
@@ -526,7 +526,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
hyp = "";
}

void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename) {
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns) {
auto logger = logger::GetOrCreateLogger("wer");
logger->set_level(spdlog::level::info);

@@ -536,7 +536,14 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
AlignmentTraversor visitor(topAlignment);
string prev_tk_classLabel = "";
logger->info("Side-by-Side alignment info going into {}", sbs_filename);
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities") << endl;
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities");
for (string col_name: extra_ref_columns) {
myfile << fmt::format("\tref_{0}", col_name);
}
for (string col_name: extra_hyp_columns) {
myfile << fmt::format("\thyp_{0}", col_name);
}
myfile << endl;

// keep track of error groupings
ErrorGroups groups_err;
@@ -554,8 +561,8 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
for (auto wer_tag: wer_tags) {
tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
}
string ref_tk = p_stitch.reftk;
string hyp_tk = p_stitch.hyptk;
string ref_tk = p_stitch.reftk.token;
string hyp_tk = p_stitch.hyptk.token;
string tag = "";

if (ref_tk == NOOP) {
@@ -587,7 +594,15 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
eff_class = tk_classLabel;
}

myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags) << endl;
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags);

for (string col_name: extra_ref_columns) {
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, true, col_name));
}
for (string col_name: extra_hyp_columns) {
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, false, col_name));
}
myfile << endl;
offset++;
}

2 changes: 1 addition & 1 deletion src/wer.h
@@ -49,5 +49,5 @@ void CalculatePrecisionRecall(wer_alignment &topAlignment, int threshold);
typedef vector<pair<size_t, string>> ErrorGroups;

void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename);
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns);
void JsonLogUnigramBigramStats(wer_alignment &topAlignment);