Skip to content

Commit

Permalink
patch to entity treatment of punctuation
Browse files: browse the repository at this point in the history
  • Loading branch information
qmac committed Oct 11, 2023
1 parent 9aba0da commit 1a9d4d4
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
/***********************************
NLP FstLoader class start
************************************/
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
Json::Value wer_sidecar)
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar)
: NlpFstLoader(records, normalization, wer_sidecar, true) {}

NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
Json::Value wer_sidecar, bool processLabels, bool use_punctuation, bool use_case)
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar,
bool processLabels, bool use_punctuation, bool use_case)
: FstLoader() {
mJsonNorm = normalization;
mWerSidecar = wer_sidecar;
Expand All @@ -29,7 +28,6 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
std::string last_label;
bool firstTk = true;


// fuse multiple rows that have the same id/label into one entry only
for (auto &row : records) {
mNlpRows.push_back(row);
Expand All @@ -41,10 +39,10 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma

// Update wer tags in records to real string labels
vector<string> real_wer_tags;
for (auto &tag: curr_row_tags) {
for (auto &tag : curr_row_tags) {
auto real_tag = tag;
if (mWerSidecar != Json::nullValue) {
real_tag = "###"+ real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
real_tag = "###" + real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
}
real_wer_tags.push_back(real_tag);
}
Expand Down Expand Up @@ -85,15 +83,15 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
}
} else {
if (!mUseCase) {
curr_tk = UnicodeLowercase(curr_tk);
curr_tk = UnicodeLowercase(curr_tk);
}
mToken.push_back(curr_tk);
mSpeakers.push_back(speaker);
}
if (use_punctuation && punctuation != "") {
mToken.push_back(punctuation);
mSpeakers.push_back(speaker);
RawNlpRecord punc_row = row;
RawNlpRecord punc_row;
punc_row.token = punc_row.punctuation;
punc_row.punctuation = "";
mNlpRows.push_back(punc_row);
Expand Down Expand Up @@ -341,19 +339,20 @@ std::vector<RawNlpRecord> NlpReader::read_from_disk(const std::string &filename)
std::vector<RawNlpRecord> vect;
io::CSVReader<13, io::trim_chars<' ', '\t'>, io::no_quote_escape<'|'>> input_nlp(filename);
// token|speaker|ts|endTs|punctuation|prepunctuation|case|tags|wer_tags|ali_comment|oldTs|oldEndTs|confidence
input_nlp.read_header(io::ignore_missing_column | io::ignore_extra_column,
"token", "speaker", "ts", "endTs", "punctuation", "prepunctuation",
"case", "tags", "wer_tags", "ali_comment", "oldTs", "oldEndTs", "confidence");

std::string token, speaker, ts, endTs, punctuation, prepunctuation, casing, tags, wer_tags, ali_comment, oldTs, oldEndTs, confidence;
while (input_nlp.read_row(token, speaker, ts, endTs, punctuation, prepunctuation, casing, tags, wer_tags, ali_comment, oldTs,
oldEndTs, confidence)) {
input_nlp.read_header(io::ignore_missing_column | io::ignore_extra_column, "token", "speaker", "ts", "endTs",
"punctuation", "prepunctuation", "case", "tags", "wer_tags", "ali_comment", "oldTs", "oldEndTs",
"confidence");

std::string token, speaker, ts, endTs, punctuation, prepunctuation, casing, tags, wer_tags, ali_comment, oldTs,
oldEndTs, confidence;
while (input_nlp.read_row(token, speaker, ts, endTs, punctuation, prepunctuation, casing, tags, wer_tags, ali_comment,
oldTs, oldEndTs, confidence)) {
RawNlpRecord record;
record.speakerId = speaker;
record.casing = casing;
record.punctuation = punctuation;
if (input_nlp.has_column("prepunctuation")) {
record.prepunctuation = prepunctuation;
record.prepunctuation = prepunctuation;
}
record.ts = ts;
record.endTs = endTs;
Expand All @@ -365,7 +364,7 @@ std::vector<RawNlpRecord> NlpReader::read_from_disk(const std::string &filename)
record.wer_tags = GetWerTags(wer_tags);
}
if (input_nlp.has_column("confidence")) {
record.confidence = confidence;
record.confidence = confidence;
}
vect.push_back(record);
}
Expand Down

0 comments on commit 1a9d4d4

Please sign in to comment.