Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 7, 2024
1 parent 5264271 commit 88b5174
Show file tree
Hide file tree
Showing 43 changed files with 114 additions and 121 deletions.
2 changes: 1 addition & 1 deletion ac_dc/anonymization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def apply_regex_anonymization(
tag_type=tag_type,
)
if anonymize_condition:
for (ent, start, end, tag) in ner:
for ent, start, end, tag in ner:
# we need to actually walk through and replace by start, end span.
sentence = sentence.replace(ent, f" <{tag}> ")
return sentence, ner
3 changes: 1 addition & 2 deletions ac_dc/deduplicate/self_deduplicate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date : 2022-01-08 22:39:29
# @Author : Chenghao Mou (mouchenghao@gmail.com)
# @Description: Self-deduplication with `datasets`
Expand Down Expand Up @@ -28,7 +27,7 @@

def main(conf: str) -> None:

with open(conf, "r") as f:
with open(conf) as f:
conf = yaml.safe_load(f.read())

if conf["load_from_disk"]["path"]:
Expand Down
6 changes: 3 additions & 3 deletions ac_dc/visualization/get_data_for_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def compute_stats(self):
)
for n in range(2, 16)
}
stats_document[
"character_repetition_ratio"
] = character_repetition_ratios
stats_document["character_repetition_ratio"] = (
character_repetition_ratios
)

word_repetition_ratios = {
n: round(
Expand Down
40 changes: 20 additions & 20 deletions ac_dc/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,16 @@ def get_cond(key, cutoff, max_cutoff):
"stopwords_ratio"
]
for i in range(len(self.docs["stopwords_ratio"])):
self.docs["stopwords_ratio"].iloc[
i
] = Filtering.compute_stopwords_ratio(
self.docs["text"].iloc[i],
self.sentencepiece_model_tok,
self.param["strip_characters"],
self.param["cond_words_augmentation"],
self.param["words_augmentation_group_sizes"],
self.param["words_augmentation_join_char"],
new_stopwords,
self.docs["stopwords_ratio"].iloc[i] = (
Filtering.compute_stopwords_ratio(
self.docs["text"].iloc[i],
self.sentencepiece_model_tok,
self.param["strip_characters"],
self.param["cond_words_augmentation"],
self.param["words_augmentation_group_sizes"],
self.param["words_augmentation_join_char"],
new_stopwords,
)
)
cutoff_def = "If the stop words ratio of a document is lower than this number, the document is removed."
cutoff_stopwords_ratio = st.slider(
Expand All @@ -326,16 +326,16 @@ def get_cond(key, cutoff, max_cutoff):
"flagged_words_ratio"
]
for i in range(len(self.docs["flagged_words_ratio"])):
self.docs["flagged_words_ratio"].iloc[
i
] = Filtering.compute_flagged_words_ratio(
self.docs["text"].iloc[i],
self.sentencepiece_model_tok,
self.param["strip_characters"],
self.param["cond_words_augmentation"],
self.param["words_augmentation_group_sizes"],
self.param["words_augmentation_join_char"],
new_flagged_words,
self.docs["flagged_words_ratio"].iloc[i] = (
Filtering.compute_flagged_words_ratio(
self.docs["text"].iloc[i],
self.sentencepiece_model_tok,
self.param["strip_characters"],
self.param["cond_words_augmentation"],
self.param["words_augmentation_group_sizes"],
self.param["words_augmentation_join_char"],
new_flagged_words,
)
)
cutoff_def = "If the flagged words ratio of a document is higher than this number, the document is removed."
max_fwr = np.max(self.docs["flagged_words_ratio"])
Expand Down
17 changes: 10 additions & 7 deletions bertin/evaluation/run_glue.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -384,19 +383,23 @@ def main():
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
model_args.config_name
if model_args.config_name
else model_args.model_name_or_path,
(
model_args.config_name
if model_args.config_name
else model_args.model_name_or_path
),
num_labels=num_labels,
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name
if model_args.tokenizer_name
else model_args.model_name_or_path,
(
model_args.tokenizer_name
if model_args.tokenizer_name
else model_args.model_name_or_path
),
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
Expand Down
15 changes: 8 additions & 7 deletions bertin/evaluation/run_ner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -364,9 +363,11 @@ def get_label_list(labels):
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
model_args.config_name
if model_args.config_name
else model_args.model_name_or_path,
(
model_args.config_name
if model_args.config_name
else model_args.model_name_or_path
),
num_labels=num_labels,
label2id=label_to_id,
id2label={i: l for l, i in label_to_id.items()},
Expand Down Expand Up @@ -636,9 +637,9 @@ def compute_metrics(p):
kwargs["dataset_tags"] = data_args.dataset_name
if data_args.dataset_config_name is not None:
kwargs["dataset_args"] = data_args.dataset_config_name
kwargs[
"dataset"
] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
kwargs["dataset"] = (
f"{data_args.dataset_name} {data_args.dataset_config_name}"
)
else:
kwargs["dataset"] = data_args.dataset_name

Expand Down
3 changes: 1 addition & 2 deletions bertin/mc4/mc4.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Perplexity Sampled mC4 dataset based on Common Crawl."""


import gzip
import json

Expand Down Expand Up @@ -404,7 +403,7 @@ def _generate_examples(self, filepaths):
for filepath in filepaths:
logger.info("generating examples from = %s", filepath)
if filepath.endswith("jsonl"):
with open(filepath, "r", encoding="utf-8") as f:
with open(filepath, encoding="utf-8") as f:
for line in f:
if line:
example = json.loads(line)
Expand Down
1 change: 0 additions & 1 deletion bertin/run_mlm_flax.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
3 changes: 1 addition & 2 deletions bertin/run_mlm_flax_stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -446,7 +445,7 @@ def restore_checkpoint(save_dir, state):
args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

with open(os.path.join(save_dir, "training_state.json"), "r") as f:
with open(os.path.join(save_dir, "training_state.json")) as f:
training_state = json.load(f)
step = training_state["step"]

Expand Down
2 changes: 1 addition & 1 deletion bertin/utils/dataset_perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_perplexity(doc):


with open("mc4-es-train-50M-stats.csv", "w") as csv:
with open("mc4-es-train-50M-steps.jsonl", "r") as data:
with open("mc4-es-train-50M-steps.jsonl") as data:
for line in tqdm(data):
text = json.loads(line)["text"]
csv.write(f"{len(text.split())},{get_perplexity(text)}\n")
1 change: 1 addition & 0 deletions cc_pseudo_crawl/python_scripts/deeper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Generate list of urls to query for next depth. We then need to use Athena to make a fancy query.
"""

import csv
import re
import subprocess
Expand Down
4 changes: 2 additions & 2 deletions cc_pseudo_crawl/python_scripts/download_warc.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ def get_warcs(batch):
existing_compressed_warcs,
)

batch["compressed_warc"], batch["download_exception"] = [
batch["compressed_warc"], batch["download_exception"] = (
list(l) for l in zip(*warcs_or_exceptions)
]
)
return batch


Expand Down
1 change: 1 addition & 0 deletions cc_pseudo_crawl/python_scripts/exact_deduplicates.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Taken from Teven and Leandro"""

import gzip
import os
import shutil
Expand Down
2 changes: 1 addition & 1 deletion cc_pseudo_crawl/python_scripts/load_all_seed_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def main():

seed_ids = []
for seed_path in args.seed_paths:
with open(seed_path, "r") as fi:
with open(seed_path) as fi:
data = csv.reader(fi)
# First line is all the headers that we remove.
seed_ids += [row[0] for row_id, row in enumerate(data) if row_id > 0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def process_batch(batch, skip_set):
# looks at up to the first 10K pages for a seed and
# records lines that appear in at least 1% of the unique pages
def get_lines_to_skip(dset, n_records, pourcentage_threshold, min_repetition_threshold):
line_counts = defaultdict(lambda: 0)
line_counts = defaultdict(int)
seen_pages = set()

seed = SeedSequence(42)
Expand Down
1 change: 1 addition & 0 deletions cc_pseudo_crawl/python_scripts/shard_by_seed_id.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Deduplicating using `datasets` is much harder, but we forgot to generate an id when building an index, so we're screwed.
"""

import logging
import subprocess
from argparse import ArgumentParser
Expand Down
3 changes: 1 addition & 2 deletions kenlm_training/cc_net/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@


class Executor(Protocol):
def __call__(self, function: Callable[..., str], *args: Iterable) -> None:
...
def __call__(self, function: Callable[..., str], *args: Iterable) -> None: ...


class SubmititRetryOnTimeout(submitit.helpers.Checkpointable):
Expand Down
18 changes: 6 additions & 12 deletions kenlm_training/cc_net/flat_hash_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,17 @@ def __repr__(self):
implementation = type(self).__name__
return f"[{implementation}, len: {len(self)}"

def __len__(self) -> int:
...
def __len__(self) -> int: ...

def __contains__(self, values: Sequence[np.uint64]) -> np.ndarray:
...
def __contains__(self, values: Sequence[np.uint64]) -> np.ndarray: ...

def __getitem__(self, values) -> np.ndarray:
...
def __getitem__(self, values) -> np.ndarray: ...

def __setitem__(self, keys, values) -> None:
...
def __setitem__(self, keys, values) -> None: ...

def items(self) -> Iterable[Tuple[np.uint64, np.uint8]]:
...
def items(self) -> Iterable[Tuple[np.uint64, np.uint8]]: ...

def keys(self) -> Iterable[np.uint64]:
...
def keys(self) -> Iterable[np.uint64]: ...

def __iter__(self) -> Iterator[np.uint64]:
return iter(self.keys())
Expand Down
19 changes: 7 additions & 12 deletions kenlm_training/cc_net/jsonql.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,8 +880,7 @@ def describe(source, columns=None, weights=None, **kwargs):
continue
if "." in k or k == ALL_DOCUMENTS:
continue
for line in display_stats(stats, k, weights=weights, **kwargs):
yield line
yield from display_stats(stats, k, weights=weights, **kwargs)


def shard(lines):
Expand All @@ -902,17 +901,13 @@ def get_or_set(dictionary, key, default):
class SimpleIO(Protocol):
"""A subset of methods from TextIO."""

def close(self) -> None:
...
def close(self) -> None: ...

def write(self, line: str) -> int:
...
def write(self, line: str) -> int: ...

def __enter__(self) -> "SimpleIO":
...
def __enter__(self) -> "SimpleIO": ...

def __exit__(self, exc_type, exc_value, traceback):
...
def __exit__(self, exc_type, exc_value, traceback): ...


def open_read(filename: ReadableFileLike) -> Iterable[str]:
Expand Down Expand Up @@ -961,7 +956,7 @@ def open_read(filename: ReadableFileLike) -> Iterable[str]:
if filename.suffix == ".gz":
file: TextIO = gzip.open(filename, "rt") # type: ignore
else:
file = open(filename, "rt")
file = open(filename)

return _close_when_exhausted(file)

Expand Down Expand Up @@ -1015,7 +1010,7 @@ def open_write(
if filename.suffix == ".gz":
return BlockedGzipWriter(Path(filename), mode, block_size="64M")

return open(filename, "wt")
return open(filename, "w")


def parse_size(size):
Expand Down
2 changes: 1 addition & 1 deletion kenlm_training/tests/test_jsonql.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def do(self, x):
def acc(values):
print("acc: started")
res = 0
for (x, _) in values:
for x, _ in values:
res += int(x)
print("acc: done")
yield f"acc: result={res}"
Expand Down
14 changes: 8 additions & 6 deletions perplexity_lenses/perplexity_lenses/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ def hub_dataset_to_dataframe(
{
text_column: sentence,
"perplexity": model.get_perplexity(sentence),
"label": x.get("labels", [])[0]
if len(x.get("labels", [])) > 0
else "NONE", # Special case for registry dataset
"label": (
x.get("labels", [])[0]
if len(x.get("labels", [])) > 0
else "NONE"
), # Special case for registry dataset
}
for sentence in x[text_column].split("\n")
]
Expand All @@ -46,9 +48,9 @@ def hub_dataset_to_dataframe(
lambda x: {
text_column: x[text_column],
"perplexity": model.get_perplexity(x[text_column]),
"label": x.get("labels", [])[0]
if len(x.get("labels", [])) > 0
else "NONE", # Special case for registry dataset
"label": (
x.get("labels", [])[0] if len(x.get("labels", [])) > 0 else "NONE"
), # Special case for registry dataset
}
)
instances = []
Expand Down
Loading

0 comments on commit 88b5174

Please sign in to comment.