Skip to content

Commit

Permalink
feat: add remove first header option
Browse files Browse the repository at this point in the history
  • Loading branch information
MicPie committed Oct 12, 2023
1 parent 8f0ebe1 commit ec88132
Showing 1 changed file with 82 additions and 48 deletions.
130 changes: 82 additions & 48 deletions data/natural/preprocess_nougat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

from tqdm import tqdm

TEXT_CUTOFF = 0
STR_CUTOFF = 0
KEEP_FIRST_HEADERS = [
"main",
"abstract",
"introduction",
"summary",
]


def load_mmd_from_path(path):
Expand All @@ -15,10 +21,14 @@ def load_mmd_from_path(path):

rm_double_asterisk_start = re.compile(r"\n\*\*"), "\n## "
rm_double_asterisk_end = re.compile(r"\*\*\n"), ""
# rm_double_asterisk_end_inline = re.compile(r"\*\*.*\w"), "\n"
rm_double_asterisk = re.compile(r"\*\*"), ""
rm_missing_page_fail = re.compile(r"\n\n\[MISSING_PAGE_FAIL:\d+\]"), ""
rm_missing_page_empty = re.compile(r"\n\n\[MISSING_PAGE_EMPTY:\d+\]"), ""
rm_missing_page_post = re.compile(r"\n\n\[MISSING_PAGE_POST\]"), ""
rm_missing_page_fail_a = re.compile(r"\n\n\[MISSING_PAGE_FAIL:\d+\]"), ""
rm_missing_page_fail_b = re.compile(r"\[MISSING_PAGE_FAIL:\d+\]"), ""
rm_missing_page_empty_a = re.compile(r"\n\n\[MISSING_PAGE_EMPTY:\d+\]"), ""
rm_missing_page_empty_b = re.compile(r"\[MISSING_PAGE_EMPTY:\d+\]"), ""
rm_missing_page_post_a = re.compile(r"\n\n\[MISSING_PAGE_POST\]"), ""
rm_missing_page_post_b = re.compile(r"\[MISSING_PAGE_POST\]"), ""
rm_figure_caption_start = re.compile(r"[Ff]igure \d+\w?\.?[:\|]?\s"), ""
rm_schema_caption_start = re.compile(r"[Ss]chema \d+\w?\.?[:\|]?\s"), ""
rm_schema_caption_start = re.compile(r"[Ss]cheme \d+\w?\.?[:\|]?\s"), ""
Expand All @@ -37,45 +47,66 @@ def load_mmd_from_path(path):

year_numbers = re.compile(r"[19,20]\d\d\,")

find_headers = re.compile("(#{1,6}.*)\\n")


def get_headers(mmd, show=False):
headers = []
for match in find_headers.finditer(mmd):
span = match.span()
headers.append((mmd[span[0] : span[1]], span))
if show:
for h in headers:
print(h)
return headers


def get_next_header_to_remove(mmd, exclude_headers):
headers = get_headers(mmd)
for header, span in headers:
for eh in exclude_headers:
if header.lower().find(eh) != -1:
return (header, span)
return False


def remove_nested_headers(mmd, header_span, verbose):
header, span = header_span
count_hashtag = header.count("#")
headers = get_headers(mmd)
header_idx = headers.index(header_span)
for i, (next_header, next_span) in enumerate(headers):
if i + 1 == len(headers):
next_header_pos = len(mmd) - 1
if i <= header_idx:
continue
if count_hashtag == next_header.count("#"):
next_header_pos = next_span[0]
if verbose:
print(f"Removed span: {span[0]}:{next_header_pos}")
mmd = mmd[: span[0]] + mmd[next_header_pos + 1 :]
return mmd


def remove_first_header(mmd):
headers = get_headers(mmd)
header, span = headers[0]
if span[0] > 0:
_, next_span = headers[1]
mmd = mmd[next_span[0] :]

def clean_mmd(mmd, verbose=False):
headers = get_headers(mmd)
header, span = headers[0]
if all([header.lower().find(kfh) == -1 for kfh in KEEP_FIRST_HEADERS]):
_, next_span = headers[1]
mmd = mmd[next_span[0] :]
return mmd


def clean_mmd(mmd, rm_first_header=False, verbose=False):
# section cleaning
find_headers = re.compile("(#{1,6}.*)\\n")

def get_headers(mmd, show=False):
headers = []
for match in find_headers.finditer(mmd):
span = match.span()
headers.append((mmd[span[0] : span[1]], span))
if show:
for h in headers:
print(h)
return headers

def get_next_header_to_remove(mmd, exclude_headers):
headers = get_headers(mmd)
for header, span in headers:
for eh in exclude_headers:
if header.lower().find(eh) != -1:
return (header, span)
return False

def remove_header(mmd, header_span):
header, span = header_span
count_hashtag = header.count("#")
headers = get_headers(mmd)
header_idx = headers.index(header_span)
for i, (next_header, next_span) in enumerate(headers):
if i + 1 == len(headers):
next_header_pos = len(mmd) - 1
if i <= header_idx:
continue
if count_hashtag == next_header.count("#"):
next_header_pos = next_span[0]
if verbose:
print(f"Removed span: {span[0]}:{next_header_pos}")
mmd = mmd[: span[0]] + mmd[next_header_pos + 1 :]
return mmd
if rm_first_header:
mmd = remove_first_header(mmd)

if verbose:
_ = get_headers(mmd, show=True)
Expand All @@ -85,16 +116,19 @@ def remove_header(mmd, header_span):
if verbose:
print(f"{header_span=}")
if isinstance(header_span, tuple):
mmd = remove_header(mmd, header_span)
mmd = remove_nested_headers(mmd, header_span, verbose)

# low level cleaning
reg_replace = [
rm_double_asterisk_start,
rm_double_asterisk_end,
# rm_double_asterisk,
rm_missing_page_fail,
rm_missing_page_empty,
rm_missing_page_post,
rm_missing_page_fail_a,
rm_missing_page_fail_b,
rm_missing_page_empty_a,
rm_missing_page_empty_b,
rm_missing_page_post_a,
rm_missing_page_post_b,
rm_figure_caption_start,
rm_schema_caption_start,
rm_fig_caption_start,
Expand Down Expand Up @@ -163,8 +197,8 @@ def create_jsonl_from_dir(path):
fn = path.split("/")[-1].split(".mmd")[0]
pbar.set_postfix_str(fn)
mmd = load_mmd_from_path(path)
text = clean_mmd(mmd)
if len(text) <= TEXT_CUTOFF:
text = clean_mmd(mmd, rm_first_header=True, verbose=False)
if len(text) <= STR_CUTOFF:
print(f"Too short text in: {fn}")
elif text.count("Journal of") > 10:
print(f'Too many "Journal of" in text: {fn}')
Expand All @@ -180,5 +214,5 @@ def create_jsonl_from_dir(path):


if __name__ == "__main__":
path_base = ""
path_base = "/Users/MMP/Documents/Projects/chemnlp3/data/natural/nougat_output/"
create_jsonl_from_dir(path_base)

0 comments on commit ec88132

Please sign in to comment.