Skip to content

Commit

Permalink
Add CLI options for output file and grouping
Browse files Browse the repository at this point in the history
  • Loading branch information
thatbudakguy committed Jul 28, 2024
1 parent 633909a commit 6aba58d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
65 changes: 37 additions & 28 deletions dphon/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
Set input format. Currently, plaintext (txt) and JSON lines (jsonl) are
supported.
-o <FMT>, --output-format <FMT> [default: txt]
Set output format. Currently, plaintext (txt), JSON lines (jsonl),
comma-separated values (csv), and html (html) are supported. Note that
you still need to redirect to a file in order to save the output.
-o <PATH>, --output-file <PATH>
Set output filename. By default, output is sent to the terminal. The
output format is determined by the file extension. Supported formats
are JSON lines (jsonl), CSV (csv), and HTML (html).
Matching Options:
-n <NUM>, --ngram-order <NUM> [default: 4]
Expand Down Expand Up @@ -92,18 +92,18 @@
import csv
import logging
import os
import sys
import time
from itertools import combinations
from pathlib import Path
from typing import Dict
from typing import Dict, List

import jsonlines
import pkg_resources
import spacy
from docopt import docopt
from rich import traceback
from rich.logging import RichHandler
from rich.padding import Padding
from rich.progress import BarColumn, Progress, SpinnerColumn
from spacy.language import Language
from spacy.tokens import Doc
Expand All @@ -115,7 +115,7 @@
from .extend import LevenshteinPhoneticExtender
from .g2p import get_sound_table_json
from .match import Match
from .reuse import MatchGraph
from .reuse import MatchGraph, MatchGroup

# Available log levels: default is WARN, -v is INFO, -vv is DEBUG
LOG_LEVELS = {
Expand Down Expand Up @@ -149,36 +149,45 @@ def run() -> None:

# process all texts
graph = process(nlp, args)
results = list(graph.matches)
logging.info(f"{len(results)} total results matching query")

# sort results by highest total score
results = sorted(results, key=lambda m: m.weighted_score, reverse=True)
# check if we're outputting to a file and find out the format
if args["--output-file"]:
output_path = Path(args["--output-file"])
output_format = output_path.suffix.lstrip(".").lower()
if output_format not in ["jsonl", "csv", "html"]:
raise ValueError(f"unsupported output format: {output_format}")

# if requested output match groups, otherwise output matches
results: List[MatchGroup] | List[Match] = []
if args["--group"]:
results = graph.groups
else:
results = list(graph.matches)

# sort results by highest weighted score
results = sorted(results, key=lambda result: result.weighted_score, reverse=True)

# output depending on provided option
if args["--output-format"] == "jsonl":
with jsonlines.Writer(sys.stdout) as writer:
for match in results:
writer.write(match.as_dict())
elif args["--output-format"] == "csv":
if output_format == "jsonl":
with jsonlines.Writer(output_path) as writer:
for result in results:
writer.write(result.as_dict())
elif output_format == "csv":
fieldnames = Match("", "", "", "").as_dict().keys()
writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames)
writer = csv.DictWriter(output_path, fieldnames=fieldnames)
writer.writeheader()
for match in results:
writer.writerow(match.as_dict())
elif args["--output-format"] == "html":
for result in results:
writer.writerow(result.as_dict())
elif output_format == "html":
console.record = True
with console.capture():
for doc in graph.docs:
for group in doc._.groups:
console.print(group)
sys.stdout.write(console.export_html())
for result in results:
console.print(Padding(result, (0, 0, 1, 0)))
console.save_html(output_path)
else:
# use system pager by default; colorize if LESS=R
with console.pager(styles=os.getenv("LESS", "") == "R"):
for doc in graph.docs:
for group in doc._.groups:
console.print(group)
for result in results:
console.print(Padding(result, (0, 0, 1, 0)))


def setup(args: Dict) -> Language:
Expand Down
13 changes: 10 additions & 3 deletions dphon/reuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,18 @@ def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
"""Format the group for display in console."""
table = Table(title=self.anchor_span.text, title_justify="left", show_header=False)
table = Table(
title=self.anchor_span.text,
title_justify="left",
show_header=False,
box=None,
)
table.add_column("doc", no_wrap=True)
table.add_column("text")
table.add_column("transcription")

# TODO: fix padding here

# render the "anchor" span first (i.e., the span that all matches share)
table.add_row(
f"{self.doc._.id} ({self.start}{self.end-1})",
Expand Down Expand Up @@ -222,9 +229,9 @@ def filter(self, predicate: Callable[[Match], bool]) -> None:
# helper for getting bounds of a match in a given document
def _bounds_in(doc):
def _bounds(match):
if match.utxt.doc == doc:
if match.u == doc._.id:
return match.utxt.start, match.utxt.end
if match.vtxt.doc == doc:
if match.v == doc._.id:
return match.vtxt.start, match.vtxt.end
raise ValueError("Match does not belong to document.", match, doc)

Expand Down

0 comments on commit 6aba58d

Please sign in to comment.