diff --git a/tools/README.md b/tools/README.md index b8b097c..bab1d31 100644 --- a/tools/README.md +++ b/tools/README.md @@ -25,3 +25,40 @@ A simple bash script that is meant for benchmarking the resource (RAM and runtim Example usage: `bash gather_runtime_metrics.sh output_for_this_release.csv` + +## sbs2fst.py +A python interface to simplify the conversion of a side-by-side file, generated from fstalign's `--output-sbs` flag, into [files that can be used to produce an FST using OpenFST](https://www.openfst.org/twiki/bin/view/FST/FstQuickTour). + +Example usage: + +`python sbs2fst.py sbs_file.txt fst_file_name` + +The output will be two files: `fst_file_name.fst` which will describe the FST in the AT&T FSM format used by OpenFST, and `fst_file_name.txt` which contains the complete list of symbols in the FST. + +The additional flags can be passed into the python script to add metadata that fstalign uses for tracking performance. These are useful to understand when fstalign picks tokens that are: only in the side-by-side's `ref_token` column (labeled by the `--left` flag), only in the side-by-side's `hyp_token` column (labeled by the `--right` flag), or in both columns because the `ref_token` and `hyp_token` agree (labeled by the `--gold` flag). + +Example usage: + +`python sbs2fst.py --tag --left VERBATIM --right NONVERBATIM --gold AGREEMENT sbs_file.txt fst_file_name` + +The output will produce an FST with tags indicating tokens that were only in the `ref_token` with `VERBATIM`, tokens that were only in the `hyp_token` with `NONVERBATIM`, and tokens that were in both columns with `AGREEMENT`. + +### Compiling the FST +Once you have used `sbs2fst.py` to produce the `.txt` and `.fst` files, you *must* then compile the FST before passing it into fstalign. An example command can be found below: + +`fstcompile --isymbols=${SYMBOLS} --osymbols=${SYMBOLS} ${TXT_FST} ${COMPILED_FST}` + +where `SYMBOLS` is the `.txt` file produced by `sbs2fst.py`, `TXT_FST` is the `.fst` file, and `COMPILED_FST` is a new `.fst` file that produces the binary FST usable by fstalign. + +Example usage: +```bash +python sbs2fst.py --tag --left VERBATIM --right NONVERBATIM --gold AGREEMENT sbs_file.txt fst_file_name +fstcompile --isymbols=fst_file_name.txt --osymbols=fst_file_name.txt fst_file_name.fst fst_file_name.compiled.fst +``` +You can now use `fst_file_name.compiled.fst` in fstalign with the corresponding symbols file as follows: +```bash +fstalign --ref fst_file_name.complied.fst --symbols fst_file_name.txt ... +``` + +Note that when you `sbs2fst.py` to produce a "tagged" FST with the `--tag` flag, fstalign will aggregate WER metrics for each of the specified tags (`--left`, `--right`, and `--gold`) in the JSON log file specified by fstalign's `--json-log` flag. + diff --git a/tools/sbs2fst.py b/tools/sbs2fst.py new file mode 100644 index 0000000..32e4ed1 --- /dev/null +++ b/tools/sbs2fst.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2023 +# Author: Miguel Ángel del Río Fernández +# All Rights Reserved + +from argparse import ArgumentParser +from collections import OrderedDict +from dataclasses import dataclass, field +from itertools import takewhile +from pathlib import Path +from typing import Dict, Generator, List, Optional, Tuple + + +@dataclass +class SbsEntry: + """ Represent a single SBS line.""" + ref_word: str + hyp_word: str + error: bool + entity_class: str + wer_tags: List[str] = field(default_factory=list) + extra_columns: List[str] = field(default_factory=list) + + @classmethod + def from_line(cls, line: str) -> 'SbsEntry': + parts = line.strip(' \n').split('\t') + if len(parts) == 4: + # old format + entry = SbsEntry(parts[0].strip(), parts[1].strip(), + parts[2] == 'ERR', parts[3]) + elif len(parts) == 5: + # New format, wer_tags + entry = SbsEntry(parts[0].strip(), parts[1].strip(), + parts[2] == 'ERR', parts[3], + [tag for tag in parts[4].split('|') if tag]) + elif len(parts) > 5: + entry = SbsEntry(parts[0].strip(), parts[1].strip(), + parts[2] == 'ERR', parts[3], + [tag for tag in parts[4].split('|') if tag], + extra_columns=parts[5:]) + else: + raise RuntimeError(f"Could not parse the line as SBS:\n{line}") + return entry + + def __str__(self): + if self.error: + err_str = "ERR" + else: + err_str = "" + wer_tags_str = "|".join(self.wer_tags) + if wer_tags_str: + wer_tags_str += "|" + return '\t'.join([self.ref_word, self.hyp_word, err_str, + self.entity_class, wer_tags_str]+self.extra_columns) + + +def load_from_file(fp: Path) -> Generator[SbsEntry, None, None]: + with open(fp) as f: + f.readline() + lines = takewhile(lambda x: not x.startswith("--------"), f.readlines()) + for line in lines: + yield SbsEntry.from_line(line) + + +class FSTState: + def __init__(self): + self.state: int = 0 + self.vocabulary: OrderedDict = OrderedDict({"": 0}) + self.unique_id: int = 0 + + def update_vocabulary(self, word): + """If `word` isn't in `self.vocabulary` add it + with it's own unique id.""" + if word not in self.vocabulary: + self.vocabulary[word] = len(self.vocabulary) + + def get_uid(self): + """Return the `self.unique_id` and increment it + by one.""" + current_uid = self.unique_id + self.unique_id += 1 + return current_uid + + +def init_args(): + parser = ArgumentParser(description="SBS to FST") + parser.add_argument("sbs_file", type=Path, help="The input SBS file") + parser.add_argument("fst_file", type=Path, help="The output FST file") + parser.add_argument( + "--left", + type=str, + default="LEFT", + help="Label for the left column. This label will be given to " + "words that occur on the left (reference) side of the SBS " + "during an ERR.", + ) + parser.add_argument( + "--right", + type=str, + default="RIGHT", + help="Label for the right column. This label will be given to " + "words that occur on the right (hypothesis) side of the SBS " + "during an ERR.", + ) + parser.add_argument( + "--gold", + type=str, + default="GOLD", + help="Label for the gold column. This is for words that both " + "transcripts agree upon in the SBS.", + ) + parser.add_argument( + "--tag", + action="store_true", + help="If set, the script will add extra tagging information", + ) + return parser.parse_args() + + +def prepare_IO( + input: Path, + output: Path, +): + """Determines if the input is a directory or file and prepares output accordingly""" + input_files = [] + output_files = [] + if input.is_dir(): + output.mkdir(parents=True, exist_ok=True) + for file in input.glob("**/*.txt"): + input_files.append(file) + output_files.append(output / file.stem) + else: + input_files = [input] + output_files = [output] + return input_files, output_files + + +def _to_fst_line(state1, state2, arc, weight: float=0): + return f"{state1} {state2} {arc} {arc} {weight}" + + +def flush_span( + span: List[str], state: int, *, tag: Optional[str] = None, branch_factor: int = 0 +) -> Tuple[List[str], int]: + """Flush the span by generating the relevant fst lines. If `tag` + is set add surrounding fst lines to correspond to the tag. + `branch_factor` can also be set to increase the initial transition + from the tag state to the first span state (SHOULD ONLY BE USED IN + COMBINATION WITH `tag`). + The primary use of the `branch_factor` is for the right side during + a disagreement -- you want the first right-side arc to go from the same + start as the left-side to a new state that isn't used by the left side + at all. So by specifying the `branch_factor` you can "skip" states. + In the context of a disagreement, the left-side will have 0 `branch_factor` + while the right-side must have a `branch_factor` the size of left-side + length. + """ + if len(span) == 0: + return [], state + + span_state = state + branch_factor + 1 + if tag: + fst_lines = [_to_fst_line(state, span_state, tag)] + else: + fst_lines = [_to_fst_line(state, span_state, span[0])] + span = span[1:] + + for token in span: + fst_lines.append(_to_fst_line(span_state, span_state + 1, token)) + span_state += 1 + + if tag: + fst_lines.append(_to_fst_line(span_state, span_state + 1, tag)) + span_state += 1 + return fst_lines, span_state + + +def agreement_flush( + gold_span: List[str], fst_state: FSTState, *, tag: bool = False, gold: Optional[str] = None +) -> List[str]: + """Flush "gold" spans when both sides of the sbs agree and update the FSTState. + If `tag` is True, adds a unique tag around the span using `gold` to label. + """ + gold_tag = None + if tag: + gold_tag = f"___MULTIREF:{fst_state.get_uid()}_{gold}___" + fst_state.update_vocabulary(gold_tag) + + gold_fst_lines, new_state = flush_span(gold_span, fst_state.state, tag=gold_tag) + fst_state.state = new_state + + return gold_fst_lines + + +def disagreement_flush( + left_span: List[str], + right_span: List[str], + fst_state: FSTState, + *, + tag: bool = False, + left: Optional[str] = None, + right: Optional[str] = None, +) -> List[str]: + """Flush the left and right spans when transcripts disagree and update the FSTState. + If `tag` is True, adds a unique tag around the left span using `left` to label and + around the right span using `right` to label. + """ + fst_lines = [] + + left_tag = None + if tag: + left_tag = f"___MULTIREF:{fst_state.get_uid()}_{left}___" + fst_state.update_vocabulary(left_tag) + + left_fst_lines, left_end_state = flush_span(left_span, fst_state.state, tag=left_tag) + fst_lines.extend(left_fst_lines) + + right_tag = None + if tag: + right_tag = f"___MULTIREF:{fst_state.get_uid()}_{right}___" + fst_state.update_vocabulary(right_tag) + + right_fst_lines, right_end_state = flush_span( + right_span, fst_state.state, tag=right_tag, branch_factor=len(left_fst_lines) + ) + fst_lines.extend(right_fst_lines) + + max_end_state = max(left_end_state, right_end_state) + # We have to return both paths back to same state to progress + fst_lines.append(_to_fst_line(left_end_state, max_end_state + 1, "")) + fst_lines.append(_to_fst_line(right_end_state, max_end_state + 1, "")) + + fst_state.state = max_end_state + 1 + + return fst_lines + + +def sbs2fst( + sbs_file: Path, + *, + tag: bool = False, + gold: Optional[str] = None, + left: Optional[str] = None, + right: Optional[str] = None, +) -> Tuple[List[str], Dict[str, int]]: + """Given an `sbs_file` create the equivalent fst object. + Optionally you can include tags by setting `tag` to true. These will be distinguished by the + tag labels provided in `gold` (agreements), `left` (words on the reference side of the sbs not in hypothesis), + and `right` (words on the hypothesis side of the sbs not in reference). + """ + sbs = load_from_file(sbs_file) + + fst_state = FSTState() + fst_lines = [] + + left_span = [] + right_span = [] + gold_span = [] + for row_idx, row in enumerate(sbs): + ref_word = "" if row.ref_word == "" else row.ref_word + hyp_word = "" if row.hyp_word == "" else row.hyp_word + + fst_state.update_vocabulary(ref_word) + fst_state.update_vocabulary(hyp_word) + + if row.ref_word == row.hyp_word: + # First flush the left & right spans to empty any disagreements + if len(left_span) > 0 or len(right_span) > 0: + disagreement_fst_lines = disagreement_flush( + left_span, right_span, fst_state, tag=tag, left=left, right=right + ) + fst_lines.extend(disagreement_fst_lines) + left_span = [] + right_span = [] + gold_span.append(row.ref_word) + else: + # First flush the gold span to empty any agreements + if len(gold_span) > 0: + gold_fst_lines = agreement_flush(gold_span, fst_state, tag=tag, gold=gold) + fst_lines.extend(gold_fst_lines) + gold_span = [] + + if ref_word != "": + left_span.append(ref_word) + if hyp_word != "": + right_span.append(hyp_word) + + # Flush the spans that have infomration. It'll only be a gold or a disagreement. Not both + if len(gold_span) > 0: + gold_fst_lines = agreement_flush(gold_span, fst_state, tag=tag, gold=gold) + fst_lines.extend(gold_fst_lines) + elif len(left_span) > 0 or len(right_span) > 0: + disagreement_fst_lines = disagreement_flush( + left_span, right_span, fst_state, tag=tag, left=left, right=right + ) + fst_lines.extend(disagreement_fst_lines) + + fst_lines.append(f"{fst_state.state}") + + return fst_lines, fst_state.vocabulary + + +def main( + sbs_file: Path, + fst_file: Path, + tag: bool = False, + gold: Optional[str] = None, + left: Optional[str] = None, + right: Optional[str] = None, +): + for inpath, outpath in zip(*prepare_IO(sbs_file, fst_file)): + fst_lines, vocabulary = sbs2fst(inpath, tag=tag, gold=gold, left=left, right=right) + + with open(f"{outpath}.fst", "w") as fstfile: + fstfile.write("\n".join(fst_lines)) + + with open(f"{outpath}.txt", "w") as fstfile: + for key, value in vocabulary.items(): + fstfile.write(f"{key} {value}\n") + + +if __name__ == "__main__": + args = init_args() + main(**vars(args))