Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FSTALIGN-37: Add flag in fstalign to allow for case-sensitive testing #51

Merged
merged 14 commits into from
Oct 19, 2023
Merged
2 changes: 2 additions & 0 deletions src/FstLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ class FstLoader {
const std::string& wer_sidecar_filename,
const std::string& json_norm_filename,
bool use_punctuation,
bool use_case,
bool symbols_file_included);

static std::unique_ptr<FstLoader> MakeHypothesisLoader(const std::string& hyp_filename,
const std::string& hyp_json_norm_filename,
bool use_punctuation,
bool use_case,
bool symbols_file_included);


Expand Down
24 changes: 16 additions & 8 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
Json::Value wer_sidecar)
: NlpFstLoader(records, normalization, wer_sidecar, true) {}

NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
Json::Value wer_sidecar, bool processLabels, bool use_punctuation)
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
Json::Value wer_sidecar, bool processLabels, bool use_punctuation, bool use_case)
: FstLoader() {
mJsonNorm = normalization;
mWerSidecar = wer_sidecar;
mUseCase = use_case;

std::string last_label;
bool firstTk = true;

Expand Down Expand Up @@ -81,8 +83,10 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
mJsonNorm[curr_label_id]["candidates"][last_idx]["verbalization"].append(curr_tk);
}
} else {
std::string lower_cased = UnicodeLowercase(curr_tk);
mToken.push_back(lower_cased);
if (!mUseCase) {
curr_tk = UnicodeLowercase(curr_tk);
}
mToken.push_back(curr_tk);
mSpeakers.push_back(speaker);
if (use_punctuation && punctuation != "") {
mToken.push_back(punctuation);
Expand Down Expand Up @@ -118,8 +122,10 @@ void NlpFstLoader::addToSymbolTable(fst::SymbolTable &symbol) const {
auto candidate = candidates[i]["verbalization"];
for (auto tk_itr : candidate) {
std::string token = tk_itr.asString();
std::string lower_cased = UnicodeLowercase(token);
AddSymbolIfNeeded(symbol, lower_cased);
if (!mUseCase) {
token = UnicodeLowercase(token);
}
AddSymbolIfNeeded(symbol, token);
}
}
}
Expand Down Expand Up @@ -250,11 +256,13 @@ so we add 2 states
auto candidate = candidates[i]["verbalization"];
for (auto tk_itr : candidate) {
std::string ltoken = std::string(tk_itr.asString());
std::string lower_cased = UnicodeLowercase(ltoken);
if (!mUseCase) {
ltoken = UnicodeLowercase(ltoken);
}
transducer.AddState();
nextState++;

int token_sym = symbol.Find(lower_cased);
int token_sym = symbol.Find(ltoken);
if (token_sym == -1) {
token_sym = symbol.Find(options.symUnk);
}
Expand Down
4 changes: 3 additions & 1 deletion src/Nlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class NlpReader {

class NlpFstLoader : public FstLoader {
public:
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels, bool use_punctuation = false);
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels, bool use_punctuation = false, bool use_case = false);
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar);
virtual ~NlpFstLoader();
virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
Expand All @@ -56,6 +56,8 @@ class NlpFstLoader : public FstLoader {
Json::Value mJsonNorm;
Json::Value mWerSidecar;
virtual const std::string &getToken(int index) const { return mToken.at(index); }
private:
bool mUseCase;
};

#endif /* NLP_H_ */
24 changes: 16 additions & 8 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ wer_alignment Fstalign(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine

vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> hyp_ctm_rows = {},
vector<RawNlpRecord> hyp_nlp_rows = {},
vector<string> one_best_tokens = {}) {
vector<string> one_best_tokens = {},
bool use_case = false) {
auto logger = logger::GetOrCreateLogger("fstalign");

// Go through top alignment and create stitches
Expand Down Expand Up @@ -287,7 +288,11 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h

part.hyp_orig = ctmPart.word;
// sanity check
std::string ctmCopy = UnicodeLowercase(ctmPart.word);
std::string ctmCopy = std::string(ctmPart.word);
if (!use_case) {
ctmCopy = UnicodeLowercase(ctmPart.word);
}

if (hyp_tk != ctmCopy) {
logger->warn(
"hum, looks like the ctm and the alignment got out of sync? [{}] vs "
Expand Down Expand Up @@ -326,7 +331,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
part.hyp_orig = token;

// sanity check
std::string token_copy = UnicodeLowercase(token);
std::string token_copy = std::string(token);
if (!use_case) {
token_copy = UnicodeLowercase(token);
}
if (hyp_tk != token_copy) {
logger->warn(
"hum, looks like the text and the alignment got out of sync? [{}] vs "
Expand Down Expand Up @@ -633,7 +641,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
}

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp) {
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case) {
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
// string composition_approach, bool record_case_stats) {
auto logger = logger::GetOrCreateLogger("fstalign");
Expand All @@ -648,19 +656,19 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(&hypLoader);
OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(&hypLoader);
if (ctm_hyp_loader) {
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {}, {}, use_case);
} else if (nlp_hyp_loader) {
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows, {}, use_case);
} else if (best_loader) {
vector<string> tokens;
tokens.reserve(best_loader->TokensSize());
for (int i = 0; i < best_loader->TokensSize(); i++) {
string token = best_loader->getToken(i);
tokens.push_back(token);
}
stitches = make_stitches(topAlignment, {}, {}, tokens);
stitches = make_stitches(topAlignment, {}, {}, tokens, use_case);
} else {
stitches = make_stitches(topAlignment);
stitches = make_stitches(topAlignment, {}, {}, {}, use_case);
}

NlpFstLoader *nlp_ref_loader = dynamic_cast<NlpFstLoader *>(&refLoader);
Expand Down
2 changes: 1 addition & 1 deletion src/fstalign.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct AlignerOptions {
// int numBests, string symbols_filename, string composition_approach);

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp = false);
AlignerOptions alignerOptions, bool add_inserts_nlp = false, bool use_case = false);
void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
AlignerOptions alignerOptions);

Expand Down
14 changes: 9 additions & 5 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ int main(int argc, char **argv) {
int levenstein_maximum_error_streak = 100;
bool record_case_stats = false;
bool use_punctuation = false;
bool use_case = false;
bool disable_approximate_alignment = false;
bool add_inserts_nlp = false;

Expand Down Expand Up @@ -123,6 +124,7 @@ int main(int argc, char **argv) {
"Record precision/recall for how well the hypothesis"
"casing matches the reference.");
get_wer->add_flag("--use-punctuation", use_punctuation, "Treat punctuation from nlp rows as separate tokens");
get_wer->add_flag("--use-case", use_case, "Keeps token casing and considers tokens with different case as different tokens");
get_wer->add_flag("--add-inserts-nlp", add_inserts_nlp, "Add inserts to NLP output");

// CLI11_PARSE(app, argc, argv);
Expand Down Expand Up @@ -154,8 +156,8 @@ int main(int argc, char **argv) {


// loading "reference" inputs
std::unique_ptr<FstLoader> hyp = FstLoader::MakeHypothesisLoader(hyp_filename, hyp_json_norm_filename, use_punctuation, !symbols_filename.empty());
std::unique_ptr<FstLoader> ref = FstLoader::MakeReferenceLoader(ref_filename, wer_sidecar_filename, json_norm_filename, use_punctuation, !symbols_filename.empty());
std::unique_ptr<FstLoader> hyp = FstLoader::MakeHypothesisLoader(hyp_filename, hyp_json_norm_filename, use_punctuation, use_case, !symbols_filename.empty());
std::unique_ptr<FstLoader> ref = FstLoader::MakeReferenceLoader(ref_filename, wer_sidecar_filename, json_norm_filename, use_punctuation, use_case, !symbols_filename.empty());

AlignerOptions alignerOptions;
alignerOptions.speaker_switch_context_size = speaker_switch_context_size;
Expand All @@ -178,7 +180,7 @@ int main(int argc, char **argv) {
}

if (command == "wer") {
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp);
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case);
} else if (command == "align") {
if (output_nlp.empty()) {
console->error("the output nlp file must be specified");
Expand Down Expand Up @@ -219,6 +221,7 @@ std::unique_ptr<FstLoader> FstLoader::MakeReferenceLoader(const std::string& ref
const std::string& wer_sidecar_filename,
const std::string& json_norm_filename,
bool use_punctuation,
bool use_case,
bool symbols_file_included) {
auto console = logger::GetLogger("console");
Json::Value obj;
Expand Down Expand Up @@ -265,7 +268,7 @@ std::unique_ptr<FstLoader> FstLoader::MakeReferenceLoader(const std::string& ref
NlpReader nlpReader = NlpReader();
console->info("reading reference nlp from {}", ref_filename);
auto vec = nlpReader.read_from_disk(ref_filename);
return std::make_unique<NlpFstLoader>(vec, obj, wer_sidecar_obj, true, use_punctuation);
return std::make_unique<NlpFstLoader>(vec, obj, wer_sidecar_obj, true, use_punctuation, use_case);
} else if (EndsWithCaseInsensitive(ref_filename, string(".ctm"))) {
console->info("reading reference ctm from {}", ref_filename);
CtmReader ctmReader = CtmReader();
Expand All @@ -288,6 +291,7 @@ std::unique_ptr<FstLoader> FstLoader::MakeReferenceLoader(const std::string& ref
std::unique_ptr<FstLoader> FstLoader::MakeHypothesisLoader(const std::string& hyp_filename,
const std::string& hyp_json_norm_filename,
bool use_punctuation,
bool use_case,
bool symbols_file_included) {
auto console = logger::GetLogger("console");

Expand Down Expand Up @@ -329,7 +333,7 @@ std::unique_ptr<FstLoader> FstLoader::MakeHypothesisLoader(const std::string& hy
auto vec = nlpReader.read_from_disk(hyp_filename);
// for now, nlp files passed as hypothesis won't have their labels handled as such
// this also means that json normalization will be ignored
return std::make_unique<NlpFstLoader>(vec, hyp_json_obj, hyp_empty_json, false, use_punctuation);
return std::make_unique<NlpFstLoader>(vec, hyp_json_obj, hyp_empty_json, false, use_punctuation, use_case);
} else if (EndsWithCaseInsensitive(hyp_filename, string(".ctm"))) {
console->info("reading hypothesis ctm from {}", hyp_filename);
CtmReader ctmReader = CtmReader();
Expand Down
2 changes: 1 addition & 1 deletion src/version.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once

#define FSTALIGNER_VERSION_MAJOR 1
#define FSTALIGNER_VERSION_MINOR 10
#define FSTALIGNER_VERSION_MINOR 11
#define FSTALIGNER_VERSION_PATCH 0
33 changes: 33 additions & 0 deletions test/data/short.aligned.case.nlp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
token|speaker|ts|endTs|punctuation|prepunctuation|case|tags|wer_tags|oldTs|oldEndTs|ali_comment|confidence
<crosstalk>|2|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
yeah|1|||,||LC|[]|[]|||del|
right|1|0.0000|0.0000|.||LC|[]|[]||||
Yeah|1|||,||UC|[]|[]|||del|
all|1|||||LC|[]|[]|||del|
right|1|0.0000|0.0000|,||LC|[]|[]|||sub(I'll)|
probably|1|0.0000|0.0000|||LC|[]|[]|||sub(do)|
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|.||LC|[]|[]||||
Are|3|0.0000|0.0000|||UC|[]|[]||||
there|3|0.0000|0.0000|||LC|[]|[]||||
any|3|0.0000|0.0000|||LC|[]|[]||||
visuals|3|0.0000|0.0000|||LC|[]|[]||||
that|3|0.0000|0.0000|||LC|[]|[]||||
come|3|0.0000|0.0000|||LC|[]|[]||||
to|3|0.0000|0.0000|||LC|[]|[]||||
mind|3|0.0000|0.0000|||LC|[]|[]||||
or|3|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
sure|1|0.0000|0.0000|.||LC|[]|[]||||
When|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
hear|1|0.0000|0.0000|||LC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
think|1|0.0000|0.0000|||LC|[]|[]||||
about|1|0.0000|0.0000|||LC|[]|[]||||
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|:||LC|[]|[]||||
foo|1|0.0000|0.0000|||LC|[]|[]|||sub(Foobar)|
a|1|0.0000|0.0000|||LC|[]|[]||||
43 changes: 43 additions & 0 deletions test/data/short.aligned.punc_case.nlp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
token|speaker|ts|endTs|punctuation|prepunctuation|case|tags|wer_tags|oldTs|oldEndTs|ali_comment|confidence
<crosstalk>|2|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
,|1|0.0000|0.0000|||UC|[]|[]||||
yeah|1|||,||LC|[]|[]|||del|
,|1|||||LC|[]|[]|||del|
right|1|0.0000|0.0000|.||LC|[]|[]||||
.|1|||||LC|[]|[]|||del|
Yeah|1|||,||UC|[]|[]|||del|
,|1|||||UC|[]|[]|||del|
all|1|||||LC|[]|[]|||del|
right|1|||,||LC|[]|[]|||del|
,|1|0.0000|0.0000|||LC|[]|[]|||sub(I'll)|
probably|1|0.0000|0.0000|||LC|[]|[]|||sub(do)|
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|.||LC|[]|[]||||
.|1|0.0000|0.0000|||LC|[]|[]|||sub(?)|
Are|3|0.0000|0.0000|||UC|[]|[]||||
there|3|0.0000|0.0000|||LC|[]|[]||||
any|3|0.0000|0.0000|||LC|[]|[]||||
visuals|3|0.0000|0.0000|||LC|[]|[]||||
that|3|0.0000|0.0000|||LC|[]|[]||||
come|3|0.0000|0.0000|||LC|[]|[]||||
to|3|0.0000|0.0000|||LC|[]|[]||||
mind|3|0.0000|0.0000|||LC|[]|[]||||
or|3|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
,|1|0.0000|0.0000|||UC|[]|[]||||
sure|1|0.0000|0.0000|.||LC|[]|[]||||
.|1|0.0000|0.0000|||LC|[]|[]||||
When|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
hear|1|0.0000|0.0000|||LC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
,|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
think|1|0.0000|0.0000|||LC|[]|[]||||
about|1|0.0000|0.0000|||LC|[]|[]||||
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|:||LC|[]|[]||||
:|1|0.0000|0.0000|||LC|[]|[]||||
foo|1|0.0000|0.0000|||LC|[]|[]|||sub(,)|
a|1|0.0000|0.0000|||LC|[]|[]||||
20 changes: 20 additions & 0 deletions test/fstalign_Test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,26 @@ TEST_CASE_METHOD(UniqueTestsFixture, "main-adapted-composition()") {
REQUIRE_THAT(result, Contains("WER: INS:2 DEL:7 SUB:4"));
}

// Runs the "wer" command on the short_punc reference/hypothesis pair with
// " --use-case" appended to the command line, then checks both the written
// NLP output file and the reported WER figures.
SECTION("wer with case(nlp output)") {
const auto result =
exec(command("wer", approach, "short_punc.ref.nlp", "short_punc.hyp.nlp", sbs_output, nlp_output, TEST_SYNONYMS)+" --use-case");
// Expected NLP output for the case-sensitive run, located under TEST_DATA.
const auto testFile = std::string{TEST_DATA} + "short.aligned.case.nlp";

// NOTE(review): compareFiles is defined outside this view; presumably it
// returns true when the two files have matching content -- confirm.
REQUIRE(compareFiles(nlp_output.c_str(), testFile.c_str()));
// The captured command output must contain the overall WER line and the
// insertion/deletion/substitution breakdown.
REQUIRE_THAT(result, Contains("WER: 6/32 = 0.1875"));
REQUIRE_THAT(result, Contains("WER: INS:0 DEL:3 SUB:3"));
}

// Same reference/hypothesis pair as the case-only section, but with both
// " --use-punctuation --use-case" appended to the command line, so the two
// flags are exercised together.
SECTION("wer with case and punctuation(nlp output)") {
const auto result =
exec(command("wer", approach, "short_punc.ref.nlp", "short_punc.hyp.nlp", sbs_output, nlp_output, TEST_SYNONYMS)+" --use-punctuation --use-case");
// Expected NLP output for the combined punctuation + case run, under TEST_DATA.
const auto testFile = std::string{TEST_DATA} + "short.aligned.punc_case.nlp";

// NOTE(review): compareFiles is defined outside this view; presumably it
// returns true when the two files have matching content -- confirm.
REQUIRE(compareFiles(nlp_output.c_str(), testFile.c_str()));
// The captured command output must contain the overall WER line and the
// insertion/deletion/substitution breakdown.
REQUIRE_THAT(result, Contains("WER: 13/42 = 0.3095"));
REQUIRE_THAT(result, Contains("WER: INS:2 DEL:7 SUB:4"));
}

// alignment tests

SECTION("align_1") {
Expand Down
Loading