Skip to content

Commit

Permalink
adding basic fast-leveinstein distance computation (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
jprobichaud authored Nov 24, 2021
1 parent 21f24ed commit 9e33c0e
Show file tree
Hide file tree
Showing 24 changed files with 989 additions and 84 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ endif()
add_library(fstaligner-common
src/fstalign.cpp
src/wer.cpp
src/fast-d.cpp
src/AdaptedComposition.cpp
src/StandardComposition.cpp
src/AlignmentTraversor.cpp
Expand Down
99 changes: 71 additions & 28 deletions src/AdaptedComposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@ AdaptedCompositionFst::AdaptedCompositionFst(const fst::StdFst &fstA, const fst:
: fstA_{fstA}, fstB_{fstB}, symbols_{NULL} {
logger_ = logger::GetOrCreateLogger("AdaptedCompositionFst");
logger_->set_level(spdlog::level::info);
#if TRACE
logger_->set_level(spdlog::level::trace);
#endif
}

AdaptedCompositionFst::AdaptedCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB, SymbolTable &symbols)
: fstA_{fstA}, fstB_{fstB} {
logger_ = logger::GetOrCreateLogger("AdaptedCompositionFst");
logger_->set_level(spdlog::level::info);
#if TRACE
logger_->set_level(spdlog::level::trace);
#endif
SetSymbols(&symbols);

FstAlignOption options;
Expand Down Expand Up @@ -237,6 +243,8 @@ bool AdaptedCompositionFst::TryGetArcsAtState(StateId fromStateId, vector<fst::S
int num_match = 0;
int num_entity = 0;

int arc_added = 0;

for (ArcIterator<StdFst> aiter(fstA_, refA); !aiter.Done(); aiter.Next()) {
const fst::StdArc &arcA = aiter.Value();

Expand All @@ -256,6 +264,7 @@ bool AdaptedCompositionFst::TryGetArcsAtState(StateId fromStateId, vector<fst::S
if (arcA.olabel == 0) {
StateId skip_eps_state_ref_id = GetOrCreateComposedState(arcA.nextstate, refB);
out_vector->push_back(StdArc(0, 0, 0.0, skip_eps_state_ref_id));
arc_added++;
continue;
}

Expand All @@ -279,6 +288,7 @@ bool AdaptedCompositionFst::TryGetArcsAtState(StateId fromStateId, vector<fst::S
// let's keep the weight to 0, this isn't an error
StateId del_state_ref_id = GetOrCreateComposedState(arcA.nextstate, refB);
out_vector->push_back(StdArc(arcA.ilabel, del_label_id_, 0.0, del_state_ref_id));
arc_added++;
} else {
// skipping this path already since we can't reach the end of it without deletions or insertions or
// substitutions
Expand All @@ -291,65 +301,98 @@ bool AdaptedCompositionFst::TryGetArcsAtState(StateId fromStateId, vector<fst::S
continue;
}

float weightA = arcA.weight.Value();

for (ArcIterator<StdFst> aiterB(fstB_, refB); !aiterB.Done(); aiterB.Next()) {
const fst::StdArc &arcB = aiterB.Value();

float weightB = arcB.weight.Value();
bool arcs_matched = false;
#if TRACE
logger_->trace("{}/{} >] word-B {} has a weight {}", dbg_count, here_snap, symbols_->Find(arcB.ilabel), weightB);
logger_->trace("{}/{} >] for {}/{} vs {}/{}, we have num_match {} and num_entity {}", dbg_count, here_snap,
arcA.olabel, symbols_->Find(arcA.olabel), arcB.ilabel, symbols_->Find(arcB.ilabel), num_match,
num_entity);
#endif

// we have a matching label
if (arcA.olabel == arcB.ilabel) {
num_match++;
arcs_matched = true;

StateId c = GetOrCreateComposedState(arcA.nextstate, arcB.nextstate);
#if TRACE
logger_->trace("{}/{} >] adding cor/{}/{} to {}, num_match = {}", dbg_count, here_snap, arcB.olabel,
symbols_->Find(arcB.olabel), c, num_match);
#endif
out_vector->push_back(StdArc(arcA.ilabel, arcB.olabel, 0.0, c));
arc_added++;
}

if (TRACE) {
logger_->trace("{}/{} >] for {}/{} vs {}/{}, we have num_match {} and num_entity {}", dbg_count, here_snap,
arcA.olabel, symbols_->Find(arcA.olabel), arcB.ilabel, symbols_->Find(arcB.ilabel), num_match,
num_entity);
}

if (num_match == 0) {
// if (num_match == 0 || num_match == num_entity) {
// this could be an insertion, this could be a substitution
// TODO: we can be more clever here
// When the WER of a section is high, greedy matches can lead to a dead-end.
// if (!arcs_matched && weightB <= 0) {
if (weightB <= 0) {
// B can be inserted...
StateId ins_state_ref_id = GetOrCreateComposedState(refA, arcB.nextstate);
StateId sub_state_ref_id = GetOrCreateComposedState(arcA.nextstate, arcB.nextstate);
#if TRACE
logger_->trace("{}/{} adding ins/{}/{}", dbg_count, here_snap, arcB.olabel, symbols_->Find(arcB.olabel));
logger_->trace("{}/{} adding sub/{}/{}", dbg_count, here_snap, arcA.ilabel, arcB.olabel);
logger_->trace("{}/{} >] adding ins/{}/{}", dbg_count, here_snap, arcB.olabel, symbols_->Find(arcB.olabel));
#endif
// out_vector->push_back(StdArc(ins_label_id, arcB.olabel, insertion_cost, ins_state_ref_id));
out_vector->push_back(StdArc(0, arcB.olabel, insertion_cost, ins_state_ref_id));
out_vector->push_back(StdArc(arcA.ilabel, arcB.olabel, substitution_cost, sub_state_ref_id));
} else {
arc_added++;
}

// When the WER of a section is high, greedy matches can lead to a dead-end.
// if (!arcs_matched && weightA <= 0 && weightB <= 0) {
if (weightA <= 0 && weightB <= 0) {
// allow sub
StateId sub_state_ref_id = GetOrCreateComposedState(arcA.nextstate, arcB.nextstate);
#if TRACE
logger_->trace("a label match was found, not putting ins/sub arcs");
logger_->trace("{}/{} >] adding sub/{}/{}", dbg_count, here_snap, arcA.ilabel, arcB.olabel);
#endif
out_vector->push_back(StdArc(arcA.ilabel, arcB.olabel, substitution_cost, sub_state_ref_id));
arc_added++;
}
}

if (num_match == 0) {
// if (num_match == 0 || num_match == num_entity) {
if (num_match == 0 || weightA < 0) {
// let's add a potential deletion
// TODO: we can be more clever here
// out_vector->push_back(StdArc(arcA.ilabel, del_label_id, deletion_cost, del_state_ref_id));
//
// When the WER of a section is high, greedy matches can lead to a dead-end.
// if (weightA > 0) {
// // we have a ref arc that /must/ matched, but didn't, skipping deletion.
// continue;
// }
StateId del_state_ref_id = GetOrCreateComposedState(arcA.nextstate, refB);
out_vector->push_back(StdArc(arcA.ilabel, 0, deletion_cost, del_state_ref_id));
if (TRACE) {
logger_->trace("{}/{} adding del/{}/{}", dbg_count, here_snap, arcA.ilabel, symbols_->Find(arcA.ilabel));
}
arc_added++;
#if TRACE
logger_->trace("{}/{} >] adding del/{}/{}", dbg_count, here_snap, arcA.ilabel, symbols_->Find(arcA.ilabel));
#endif
} else {
if (TRACE) {
logger_->trace("a label match was found, not putting del arc");
}
#if TRACE
logger_->trace("{}/{} >] >]a label match was found, not putting del arc", dbg_count, here_snap);
#endif
}
}

if (fstA_.NumArcs(refA) == 0) {
// we reached the end of the A graph, but what about B?
#if TRACE
logger_->trace("{}/{} >] end of graph A found at {}", dbg_count, here_snap, refA);
#endif

for (ArcIterator<StdFst> aiterB(fstB_, refB); !aiterB.Done(); aiterB.Next()) {
const fst::StdArc &arcB = aiterB.Value();
arc_added++;
float weightB = arcB.weight.Value();
if (weightB > 0) {
continue;
}

#if TRACE
logger_->trace("{}/{} >] adding ins/{}/{}", dbg_count, here_snap, arcB.olabel, symbols_->Find(arcB.olabel));
#endif
StateId ins_state_ref_id = GetOrCreateComposedState(refA, arcB.nextstate);
// out_vector->push_back(StdArc(ins_label_id, arcB.olabel, insertion_cost, ins_state_ref_id));
out_vector->push_back(StdArc(0, arcB.olabel, insertion_cost, ins_state_ref_id));
Expand All @@ -368,7 +411,7 @@ void AdaptedCompositionFst::SetSymbols(fst::SymbolTable *symbols) {
entity_label_ids.resize(symbols->NumSymbols(), false);

// mostly for optimization purpose
logger_->info("{}:{} we created 2 vector<bool> of {} items", __FILE__, __LINE__, symbols->NumSymbols());
logger_->debug("{}:{} we created 2 vector<bool> of {} items", __FILE__, __LINE__, symbols->NumSymbols());

for (SymbolTableIterator siter(*symbols); !siter.Done(); siter.Next()) {
int64 sid = siter.Value();
Expand Down Expand Up @@ -448,4 +491,4 @@ void AdaptedCompositionFst::DebugComposedGraph() {
copy_fst.Write(outfile, wopts);
}
} // end of debug code
}
}
33 changes: 31 additions & 2 deletions src/Ctm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void CtmFstLoader::addToSymbolTable(SymbolTable &symbol) const {
}
}

StdVectorFst CtmFstLoader::convertToFst(const SymbolTable &symbol) const {
StdVectorFst CtmFstLoader::convertToFst(const SymbolTable &symbol, std::vector<int> map) const {
auto logger = logger::GetOrCreateLogger("ctmloader");
//
StdVectorFst transducer;
Expand All @@ -48,20 +48,49 @@ StdVectorFst CtmFstLoader::convertToFst(const SymbolTable &symbol) const {

int prevState = 0;
int nextState = 1;
int wc = 0;
int map_sz = map.size();
for (TokenType::const_iterator i = mToken.begin(); i != mToken.end(); ++i) {
std::string token = *i;
std::transform(token.begin(), token.end(), token.begin(), ::tolower);
transducer.AddState();

transducer.AddArc(prevState, StdArc(symbol.Find(token), symbol.Find(token), 0.0f, nextState));
if (map_sz > wc && map[wc] > 0) {
transducer.AddArc(prevState, StdArc(symbol.Find(token), symbol.Find(token), 1.0f, nextState));
} else {
transducer.AddArc(prevState, StdArc(symbol.Find(token), symbol.Find(token), 0.0f, nextState));
}

prevState = nextState;
nextState++;
wc++;
}

transducer.SetFinal(prevState, 0.0f);
return transducer;
}

// Returns the symbol id of every CTM token, in token order.
//
// addToSymbolTable(symbol) is called first, so `symbol` may be modified.
// Any token whose lookup still yields -1 is replaced by the id of the
// FstAlignOption::symUnk symbol.
std::vector<int> CtmFstLoader::convertToIntVector(fst::SymbolTable &symbol) const {
  auto logger = logger::GetOrCreateLogger("ctmloader");
  addToSymbolTable(symbol);

  int token_count = mToken.size();
  logger->info("creating std::vector<int> for CTM for {} tokens", token_count);

  std::vector<int> ids;
  ids.reserve(token_count);

  FstAlignOption options;
  for (const std::string &token : mToken) {
    int id = symbol.Find(token);
    if (id == -1) {
      // not in the table: fall back to the unknown-word symbol
      id = symbol.Find(options.symUnk);
    }
    ids.emplace_back(id);
  }

  return ids;
}

/***************************************
CTM FST Loader Class End
***************************************/
Expand Down
3 changes: 2 additions & 1 deletion src/Ctm.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class CtmFstLoader : public FstLoader {
~CtmFstLoader();
vector<RawCtmRecord> mCtmRows;
virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol) const;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const;
virtual std::vector<int> convertToIntVector(fst::SymbolTable &symbol) const;
virtual const std::string &getToken(int index) const { return mToken.at(index); }
};

Expand Down
15 changes: 12 additions & 3 deletions src/FstFileLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@ FstFileLoader::FstFileLoader(std::string filename) : FstLoader(), filename_(file

// No-op: this loader leaves `symbol` untouched.
void FstFileLoader::addToSymbolTable(fst::SymbolTable& symbol) const {}

// Reads the transducer stored in the file named by filename_ and returns it
// by value.
//
// Neither `symbol` nor `map` is used by this loader; they are only part of
// the signature.
//
// Throws std::runtime_error when the file cannot be read as a StdVectorFst.
fst::StdVectorFst FstFileLoader::convertToFst(const fst::SymbolTable& symbol, std::vector<int> map) const {
  auto logger = logger::GetOrCreateLogger("FstFileLoader");

  // Read() hands back a heap-allocated FST that we own, or a null pointer when
  // the file is missing or malformed; never dereference it unchecked.
  fst::StdVectorFst* loaded = fst::StdVectorFst::Read(filename_);
  if (loaded == nullptr) {
    logger->error("Unable to read an FST from {}", filename_);
    throw std::runtime_error("Unable to read an FST from " + filename_);
  }

  logger->info("Total FST has {} states.", loaded->NumStates());

  // Copy into a local so the heap object can be released before returning;
  // returning (*loaded) directly would leak it.
  fst::StdVectorFst transducer(*loaded);
  delete loaded;
  return transducer;
}

// Not supported for FST inputs: logs an error and returns an empty vector.
// `symbol` is not read or modified.
std::vector<int> FstFileLoader::convertToIntVector(fst::SymbolTable& symbol) const {
  logger::GetOrCreateLogger("FstFileLoader")->error("convertToIntVector isn't implemented for FST inputs");
  return std::vector<int>();
}

// Nothing to release explicitly; members clean up after themselves.
FstFileLoader::~FstFileLoader() = default;
4 changes: 3 additions & 1 deletion src/FstFileLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <fstream>
#include <stdexcept>
#include <vector>

#include "FstLoader.h"
#include "utilities.h"
Expand All @@ -22,7 +23,8 @@ class FstFileLoader : public FstLoader {
~FstFileLoader();

virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol) const;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const;
virtual std::vector<int> convertToIntVector(fst::SymbolTable &symbol) const;

private:
std::string filename_;
Expand Down
4 changes: 3 additions & 1 deletion src/FstLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ FstLoader.h
#ifndef __FSTLOADER_H_
#define __FSTLOADER_H_

#include <vector>
#include "utilities.h"

class FstLoader {
Expand All @@ -20,7 +21,8 @@ class FstLoader {
virtual ~FstLoader();
virtual void addToSymbolTable(fst::SymbolTable &symbol) const = 0;
static void AddSymbolIfNeeded(fst::SymbolTable &symbol, std::string str_value);
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol) const = 0;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const = 0;
virtual std::vector<int> convertToIntVector(fst::SymbolTable &symbol) const = 0;
};

#endif /* __FSTLOADER_H_ */
34 changes: 32 additions & 2 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,29 @@ void NlpFstLoader::addToSymbolTable(fst::SymbolTable &symbol) const {
}
}

fst::StdVectorFst NlpFstLoader::convertToFst(const fst::SymbolTable &symbol) const {
// Returns the symbol id of every NLP token, in token order.
//
// addToSymbolTable(symbol) is called before the lookups, so `symbol` may be
// modified. Any token whose lookup yields -1 is replaced by the id of the
// FstAlignOption::symUnk symbol.
std::vector<int> NlpFstLoader::convertToIntVector(fst::SymbolTable &symbol) const {
  auto logger = logger::GetOrCreateLogger("NlpFstLoader");
  logger->info("convertToIntVector() Building a std::vector<int> from NLP rows");
  addToSymbolTable(symbol);

  int token_count = mToken.size();
  std::vector<int> ids;
  ids.reserve(token_count);

  FstAlignOption options;
  for (const std::string &token : mToken) {
    int id = symbol.Find(token);
    if (id == -1) {
      // not in the table: fall back to the unknown-word symbol
      id = symbol.Find(options.symUnk);
    }
    ids.emplace_back(id);
  }

  // returned by value; no std::move needed on a local
  return ids;
}

fst::StdVectorFst NlpFstLoader::convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const {
auto logger = logger::GetOrCreateLogger("NlpFstLoader");
fst::StdVectorFst transducer;

Expand All @@ -141,6 +163,8 @@ fst::StdVectorFst NlpFstLoader::convertToFst(const fst::SymbolTable &symbol) con

int prevState = 0;
int nextState = 1;
int map_sz = map.size();
int wc = 0;
for (TokenType::const_iterator i = mToken.begin(); i != mToken.end(); ++i) {
transducer.AddState();

Expand All @@ -150,7 +174,13 @@ fst::StdVectorFst NlpFstLoader::convertToFst(const fst::SymbolTable &symbol) con
token_sym = symbol.Find(options.symUnk);
}

transducer.AddArc(prevState, fst::StdArc(token_sym, token_sym, 0.0f, nextState));
// logger->info("wc {}, token {}, map[wc] = {}, map_sz = {}", wc, token, map[wc], map_sz);
if (map_sz > wc && map[wc] > 0) {
transducer.AddArc(prevState, fst::StdArc(token_sym, token_sym, 1.0f, nextState));
} else {
transducer.AddArc(prevState, fst::StdArc(token_sym, token_sym, 0.0f, nextState));
}
wc++;

if (isEntityLabel(token)) {
/*
Expand Down
4 changes: 3 additions & 1 deletion src/Nlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ class NlpFstLoader : public FstLoader {
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization);
virtual ~NlpFstLoader();
virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol) const;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const;
virtual std::vector<int> convertToIntVector(fst::SymbolTable &symbol) const;

int GetProperSymbolId(const fst::SymbolTable &symbol, string token, string symUnk) const;
vector<RawNlpRecord> mNlpRows;
vector<std::string> mSpeakers;
Expand Down
Loading

0 comments on commit 9e33c0e

Please sign in to comment.