diff --git a/dphon/cli.py b/dphon/cli.py index 3b9e3fc..9f2f736 100644 --- a/dphon/cli.py +++ b/dphon/cli.py @@ -22,10 +22,10 @@ Set input format. Currently, plaintext (txt) and JSON lines (jsonl) are supported. - -o , --output-format [default: txt] - Set output format. Currently, plaintext (txt), JSON lines (jsonl), - comma-separated values (csv), and html (html) are supported. Note that - you still need to redirect to a file in order to save the output. + -o , --output-file + Set output filename. By default, output is sent to the terminal. The + output format is determined by the file extension. Supported formats + are JSON lines (jsonl), CSV (csv), and HTML (html). Matching Options: -n , --ngram-order [default: 4] @@ -92,11 +92,10 @@ import csv import logging import os -import sys import time from itertools import combinations from pathlib import Path -from typing import Dict +from typing import Dict, List import jsonlines import pkg_resources @@ -104,6 +103,7 @@ from docopt import docopt from rich import traceback from rich.logging import RichHandler +from rich.padding import Padding from rich.progress import BarColumn, Progress, SpinnerColumn from spacy.language import Language from spacy.tokens import Doc @@ -115,7 +115,7 @@ from .extend import LevenshteinPhoneticExtender from .g2p import get_sound_table_json from .match import Match -from .reuse import MatchGraph +from .reuse import MatchGraph, MatchGroup # Available log levels: default is WARN, -v is INFO, -vv is DEBUG LOG_LEVELS = { @@ -149,36 +149,45 @@ def run() -> None: # process all texts graph = process(nlp, args) - results = list(graph.matches) - logging.info(f"{len(results)} total results matching query") - # sort results by highest total score - results = sorted(results, key=lambda m: m.weighted_score, reverse=True) + # check if we're outputting to a file and find out the format + if args["--output-file"]: + output_path = Path(args["--output-file"]) + output_format = output_path.suffix.lstrip(".").lower() + if output_format not in ["jsonl", "csv", "html"]: + raise ValueError(f"unsupported output format: {output_format}") + + # if requested output match groups, otherwise output matches + results: List[MatchGroup] | List[Match] = [] + if args["--group"]: + results = graph.groups + else: + results = list(graph.matches) + + # sort results by highest weighted score + results = sorted(results, key=lambda result: result.weighted_score, reverse=True) # output depending on provided option - if args["--output-format"] == "jsonl": - with jsonlines.Writer(sys.stdout) as writer: - for match in results: - writer.write(match.as_dict()) - elif args["--output-format"] == "csv": + if output_format == "jsonl": + with jsonlines.Writer(output_path) as writer: + for result in results: + writer.write(result.as_dict()) + elif output_format == "csv": fieldnames = Match("", "", "", "").as_dict().keys() - writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames) + writer = csv.DictWriter(output_path, fieldnames=fieldnames) writer.writeheader() - for match in results: - writer.writerow(match.as_dict()) - elif args["--output-format"] == "html": + for result in results: + writer.writerow(result.as_dict()) + elif output_format == "html": console.record = True - with console.capture(): - for doc in graph.docs: - for group in doc._.groups: - console.print(group) - sys.stdout.write(console.export_html()) + for result in results: + console.print(Padding(result, (0, 0, 1, 0))) + console.save_html(output_path) else: # use system pager by default; colorize if LESS=R with console.pager(styles=os.getenv("LESS", "") == "R"): - for doc in graph.docs: - for group in doc._.groups: - console.print(group) + for result in results: + console.print(Padding(result, (0, 0, 1, 0))) def setup(args: Dict) -> Language: diff --git a/dphon/reuse.py b/dphon/reuse.py index 4d66165..1194206 100644 --- a/dphon/reuse.py +++ b/dphon/reuse.py @@ -36,11 +36,18 @@ def __rich_console__( self, console: Console, options: ConsoleOptions ) -> RenderResult: """Format the group for display in console.""" - table = Table(title=self.anchor_span.text, title_justify="left", show_header=False) + table = Table( + title=self.anchor_span.text, + title_justify="left", + show_header=False, + box=None, + ) table.add_column("doc", no_wrap=True) table.add_column("text") table.add_column("transcription") + # TODO: fix padding here + # render the "anchor" span first (i.e., the span that all matches share) table.add_row( f"{self.doc._.id} ({self.start}–{self.end-1})", @@ -222,9 +229,9 @@ def filter(self, predicate: Callable[[Match], bool]) -> None: # helper for getting bounds of a match in a given document def _bounds_in(doc): def _bounds(match): - if match.utxt.doc == doc: + if match.u == doc._.id: return match.utxt.start, match.utxt.end - if match.vtxt.doc == doc: + if match.v == doc._.id: return match.vtxt.start, match.vtxt.end raise ValueError("Match does not belong to document.", match, doc)