diff --git a/requirements-service.txt b/requirements-service.txt index 1b189b7..5d0f925 100644 --- a/requirements-service.txt +++ b/requirements-service.txt @@ -7,8 +7,8 @@ tqdm==4.32.2 neuralcoref==4.0 argparse scikit-learn -bert-extractive-summarizer==0.6.1 +bert-extractive-summarizer Flask flask-cors nltk -https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz +https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz \ No newline at end of file diff --git a/summarizer/sentence_handler.py b/summarizer/sentence_handler.py index 5808cce..a9fbee5 100644 --- a/summarizer/sentence_handler.py +++ b/summarizer/sentence_handler.py @@ -7,7 +7,26 @@ class SentenceHandler(object): def __init__(self, language=English): self.nlp = language() - self.nlp.add_pipe(self.nlp.create_pipe('sentencizer')) + + try: + self.nlp.add_pipe(self.nlp.create_pipe('sentencizer')) + self.is_spacy_3 = False + except: + self.nlp.add_pipe("sentencizer") + self.is_spacy_3 = True + + def sentence_processor(self, doc, min_length: int = 40, max_length: int = 600): + to_return = [] + + for c in doc.sents: + if max_length > len(c.text.strip()) > min_length: + + if self.is_spacy_3: + to_return.append(c.text.strip()) + else: + to_return.append(c.string.strip()) + + return to_return def process(self, body: str, min_length: int = 40, max_length: int = 600) -> List[str]: """ @@ -19,7 +38,7 @@ def process(self, body: str, min_length: int = 40, max_length: int = 600) -> Lis :return: Returns a list of sentences. """ doc = self.nlp(body) - return [c.string.strip() for c in doc.sents if max_length > len(c.string.strip()) > min_length] + return self.sentence_processor(doc, min_length, max_length) def __call__(self, body: str, min_length: int = 40, max_length: int = 600) -> List[str]: return self.process(body, min_length, max_length) diff --git a/tests/test_summary_items.py b/tests/test_summary_items.py index 2567175..103605d 100644 --- a/tests/test_summary_items.py +++ b/tests/test_summary_items.py @@ -19,7 +19,7 @@ def summarizer(): @pytest.fixture() def summarizer_multi_hidden(): - return Summarizer('distilbert-base-uncased', hidden=[-1,-2,-3]) + return Summarizer('distilbert-base-uncased', hidden=[-1,-2]) @pytest.fixture()