Skip to content

Commit

Permalink
Fix laser_similarity when batch is empty
Browse files Browse the repository at this point in the history
The batch will be empty sometimes because `threshold.py` is caching the results. Fixes #113.
  • Loading branch information
jelmervdl committed Aug 21, 2023
1 parent ab02f0c commit 05943d1
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions opuscleaner/filters/laser_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@


def _compute_similarity(laser: Laser, batch: List[Tuple[str, str]], src_lang: str, tgt_lang: str) -> List[float]:
assert len(batch) > 0
embeddings_src = laser.embed_sentences([line[0] for line in batch], lang=src_lang)
embeddings_tgt = laser.embed_sentences([line[1] for line in batch], lang=tgt_lang)

return [float(sim) for sim in _cosine_sim(embeddings_src, embeddings_tgt)]


Expand All @@ -41,7 +41,7 @@ def chunked(iterable: Iterable[T], *, chunk_size:Optional[int]=None, chunk_time:
it = iter(iterable)

# Initial set of measurements we then interpolate from
limit_samples = iter([8, 16, 32, 64, 128, 256])
limit_samples = iter([32, 64, 128, 256, 512, 1024])

# Chunk size limit for the next chunk
limit = chunk_size or next(limit_samples)
Expand All @@ -53,6 +53,12 @@ def chunked(iterable: Iterable[T], *, chunk_size:Optional[int]=None, chunk_time:
# Create a chunk
chunk = [el for _, el in zip(range(limit), it)]

# Did we reach the end because the last read was accidentally
# exactly the remainder of the dataset? Or because there was no
# input to begin with?
if not chunk:
return

# Measure how long it takes before we are asked for the next chunk
yield_time = time.monotonic()
yield chunk
Expand Down Expand Up @@ -80,7 +86,7 @@ def main():
description="Filter a parallel dataset using LASER.")
parser.add_argument("--verbose", action="store_true", help="Print tuning info")
parser.add_argument("--batch-size", type=int, help="LASER batch size")
parser.add_argument("--batch-latency", type=float, default=30.0, help="Tune batch size to process a batch every N seconds (defaults to 30s, ignored if --batch-size is given)")
parser.add_argument("--batch-latency", type=float, default=10.0, help="Tune batch size to process a batch every N seconds (defaults to 30s, ignored if --batch-size is given)")

This comment has been minimized.

Copy link
@kpu

kpu Aug 21, 2023

Documentation doesn't match code: the `--batch-latency` help text still says "defaults to 30s", but the default is now 10.0.

parser.add_argument("--src-lang", type=str, required=True, help="Two-letter source language code (ISO 639-1)")
parser.add_argument("--tgt-lang", type=str, required=True, help="Two-letter target language code (ISO 639-1)")

Expand All @@ -91,7 +97,7 @@ def main():
args = parser.parse_args()

if not args.scores and args.threshold is None:
print("Either use --threshold or --scores")
print("Either use --threshold or --scores", file=sys.stderr)

laser = Laser()

Expand All @@ -101,7 +107,7 @@ def main():

if args.scores:
for score in scores:
print(score, file=sys.stdout)
print(score)
else:
for line, score in zip(batch, scores):
if score >= args.threshold:
Expand Down

0 comments on commit 05943d1

Please sign in to comment.