Make _tokenize aware of whether it is called on XML input
Thomas Proisl committed Oct 20, 2023
1 parent abe15d8 commit afad123
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/somajo/somajo.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import functools
 import itertools
 import multiprocessing
 
@@ -46,29 +47,43 @@ def __init__(self, language, *, split_camel_case=False, split_sentences=True, xm
         if self.split_sentences:
             self._sentence_splitter = SentenceSplitter(language=self.language)
 
-    def _tokenize(self, token_list):
+    def _tokenize(self, token_info, xml_input):
         """Tokenize and sentence split a single token_dll."""
+        # unpack token_info
+        # resolve entities in xml
+        # convert raw to nfc
+        # align nfc
+        # tokenize
+        # find character offsets
+        token_list = token_info
         token_dll = doubly_linked_list.DLL(token_list)
         tokens = self._tokenizer._tokenize(token_dll)
         if self.split_sentences:
             tokens = self._sentence_splitter._split_sentences(tokens)
         return tokens
 
-    def _parallel_tokenize(self, token_lists, *, parallel=1, strip_tags=False):
+    def _parallel_tokenize(self, token_info, *, parallel=1, strip_tags=False, xml_input=False):
         """Tokenize and sentence split an iterable of token_dlls; optional
         parallelization.
 
         """
         def partok():
             with multiprocessing.Pool(min(parallel, multiprocessing.cpu_count())) as pool:
-                tokens = pool.imap(self._tokenize, token_lists, 250)
+                tokens = pool.imap(
+                    functools.partial(self._tokenize, xml_input=xml_input),
+                    token_info,
+                    250
+                )
                 for par in tokens:
                     yield par
 
         if parallel > 1:
             tokens = partok()
         else:
-            tokens = map(self._tokenize, token_lists)
+            tokens = map(
+                functools.partial(self._tokenize, xml_input=xml_input),
+                token_info
+            )
         if self.split_sentences:
             tokens = itertools.chain.from_iterable(tokens)
             tokens = self._sentence_splitter._merge_empty_sentences(tokens)
@@ -89,8 +104,8 @@ def _tokenize_xml(self, xml_data, is_file, eos_tags, strip_tags, parallel, prune
         eos_tags = set(eos_tags)
         if prune_tags is not None:
             prune_tags = set(prune_tags)
-        token_lists = utils.xml_chunk_generator(xml_data, is_file, eos_tags=eos_tags, prune_tags=prune_tags)
-        tokens = self._parallel_tokenize(token_lists, parallel=parallel, strip_tags=strip_tags)
+        token_info = utils.xml_chunk_generator(xml_data, is_file, eos_tags=eos_tags, prune_tags=prune_tags)
+        tokens = self._parallel_tokenize(token_info, parallel=parallel, strip_tags=strip_tags, xml_input=True)
         if not (strip_tags and self.xml_sentences is None):
             tokens = map(utils.escape_xml_tokens, tokens)
         return tokens
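
The mechanical core of the change: Pool.imap passes each chunk to the worker as a single positional argument, so the new xml_input keyword cannot be supplied directly. functools.partial binds it in advance, and a partial (unlike a lambda) can be pickled for the worker processes as long as the wrapped callable is picklable. A minimal, self-contained sketch of the same pattern; the work function and the chunks data are illustrative stand-ins, not SoMaJo code:

import functools
import multiprocessing

def work(chunk, xml_input):
    # Stand-in for _tokenize: a module-level function, so it pickles cleanly.
    return (chunk.upper(), xml_input)

if __name__ == "__main__":
    chunks = ["foo", "bar", "baz"]
    with multiprocessing.Pool(2) as pool:
        # imap hands each chunk to the worker as the sole positional
        # argument; xml_input is fixed up front via functools.partial.
        for result in pool.imap(functools.partial(work, xml_input=True), chunks, 2):
            print(result)

Using the same partial in the serial map branch keeps the two code paths identical except for the pool, which is why the commit applies it in both places.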
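The new placeholder comments in _tokenize lay out the planned XML-aware pipeline: unpack the chunk info, resolve XML entities, normalize the raw text to NFC, align the normalized text with the raw input, tokenize, and compute character offsets. For now the body only unpacks token_info. The alignment step will matter because NFC normalization can change the number of code points, so offsets into the normalized text stop matching the raw input; a small standard-library illustration (not SoMaJo code):

import unicodedata

raw = "Cafe\u0301"  # "Café" with a combining acute accent: 5 code points
nfc = unicodedata.normalize("NFC", raw)  # precomposed "é": 4 code points
assert raw != nfc and len(raw) == 5 and len(nfc) == 4

Any character offsets computed on the NFC string therefore have to be mapped back to positions in the raw, possibly entity-containing, input.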
