class TextMerger:
    """Merge consecutive passages into chunks of roughly ``chunk_size`` tokens.

    Token counts are measured with a tiktoken encoding, selected either by
    model name or by encoding name.
    """

    def __init__(self, model_name=None, encoding_name="gpt2"):
        """Select the tokenizer used for counting tokens.

        Args:
            model_name: if given, the encoding is looked up with
                ``tiktoken.encoding_for_model``.
            encoding_name: encoding passed to ``tiktoken.get_encoding`` when
                ``model_name`` is None.
        """
        if model_name is not None:
            self.enc = tiktoken.encoding_for_model(model_name)
        else:
            self.enc = tiktoken.get_encoding(encoding_name)

    def encode(self, text, allowed_special=None, disallowed_special="all"):
        """Tokenize ``text`` with the configured encoding.

        Args:
            text: the string to tokenize.
            allowed_special: forwarded to the encoder; ``None`` (the default)
                is replaced by a fresh empty set so that no mutable default
                object is shared between calls.
            disallowed_special: forwarded to the encoder unchanged.

        Returns:
            Whatever the underlying encoder's ``encode`` returns (the callers
            in this class only take its ``len``).
        """
        if allowed_special is None:
            allowed_special = set()
        return self.enc.encode(
            text,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )

    def merge_passages(self, passages, chunk_size, tolerance=0.2):
        """Greedily group consecutive passages into token-bounded chunks.

        Passages are accumulated in order. After each addition the joined text
        is tokenized:

        * above ``chunk_size * (1 + tolerance)``: everything accumulated
          before the last passage is emitted and the last passage starts the
          next chunk (a single passage that is too long on its own is emitted
          as-is);
        * between ``chunk_size`` and that upper limit (both inclusive): the
          accumulated passages are emitted as one chunk;
        * otherwise accumulation continues.

        Args:
            passages: iterable of dicts with at least the keys ``'text'`` and
                ``'coordinates'`` (both joined as strings).
            chunk_size: target number of tokens per chunk.
            tolerance: fraction of ``chunk_size`` by which a chunk may exceed
                the target.

        Returns:
            A list of dicts with keys ``text`` (texts joined by a space),
            ``coordinates`` (coordinates joined by ``;``), ``type``,
            ``section`` and ``subSection``.
        """
        max_tokens = chunk_size + chunk_size * tolerance

        merged_texts = []
        merged_coordinates = []
        current_texts = []
        current_coordinates = []

        for passage in passages:
            current_texts.append(passage['text'])
            current_coordinates.append(passage['coordinates'])

            # The joined text is re-tokenized each time because token counts
            # of separate passages do not simply add up once they are joined.
            token_count = len(self.encode(" ".join(current_texts)))

            if token_count > max_tokens:
                if len(current_texts) > 1:
                    # The newest passage caused the overflow: emit what came
                    # before it and carry the newest one over.
                    merged_texts.append(current_texts[:-1])
                    merged_coordinates.append(current_coordinates[:-1])
                    current_texts = current_texts[-1:]
                    current_coordinates = current_coordinates[-1:]
                else:
                    # A single passage already exceeds the limit; it cannot be
                    # split here, so it becomes a chunk of its own.
                    merged_texts.append(current_texts)
                    merged_coordinates.append(current_coordinates)
                    current_texts = []
                    current_coordinates = []
            elif token_count >= chunk_size:
                # Inside the accepted window (upper bound inclusive).
                merged_texts.append(current_texts)
                merged_coordinates.append(current_coordinates)
                current_texts = []
                current_coordinates = []

        # Whatever is left did not reach chunk_size; keep it as a final chunk.
        if current_texts:
            merged_texts.append(current_texts)
            merged_coordinates.append(current_coordinates)

        return [
            {
                "text": " ".join(texts),
                "coordinates": ";".join(coordinates),
                "type": "aggregated chunks",
                "section": "mixed",
                "subSection": "mixed",
            }
            for texts, coordinates in zip(merged_texts, merged_coordinates)
        ]
self.chain.run(input_documents=relevant_documents, question=query) @@ -196,7 +270,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, if verbose: print("File", pdf_file_path) filename = Path(pdf_file_path).stem - coordinates = True if chunk_size == -1 else False + coordinates = True # if chunk_size == -1 else False structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates) biblio = structure['biblio'] @@ -209,29 +283,25 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, metadatas = [] ids = [] - if chunk_size < 0: - for passage in structure['passages']: - biblio_copy = copy.copy(biblio) - if len(str.strip(passage['text'])) > 0: - texts.append(passage['text']) + if chunk_size > 0: + new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size) + else: + new_passages = structure['passages'] - biblio_copy['type'] = passage['type'] - biblio_copy['section'] = passage['section'] - biblio_copy['subSection'] = passage['subSection'] - biblio_copy['coordinates'] = passage['coordinates'] - metadatas.append(biblio_copy) + for passage in new_passages: + biblio_copy = copy.copy(biblio) + if len(str.strip(passage['text'])) > 0: + texts.append(passage['text']) - ids.append(passage['passage_id']) - else: - document_text = " ".join([passage['text'] for passage in structure['passages']]) - # text_splitter = CharacterTextSplitter.from_tiktoken_encoder( - text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=chunk_size, - chunk_overlap=chunk_size * perc_overlap - ) - texts = text_splitter.split_text(document_text) - metadatas = [biblio for _ in range(len(texts))] - ids = [id for id, t in enumerate(texts)] + biblio_copy['type'] = passage['type'] + biblio_copy['section'] = passage['section'] + biblio_copy['subSection'] = passage['subSection'] + biblio_copy['coordinates'] = passage['coordinates'] + metadatas.append(biblio_copy) + + 
def _sample_passages():
    """Return the four-passage fixture shared by the merge tests.

    Coordinates are the 1-based position of each passage, as strings.
    """
    sentences = (
        "The quick brown fox jumps over the tree",
        "and went straight into the mouth of a bear.",
        "The color of the colors is a color with colors",
        "the main colors are not the colorw we show",
    )
    return [
        {'text': sentence, 'coordinates': str(position)}
        for position, sentence in enumerate(sentences, start=1)
    ]


def test_merge_passages_small_chunk():
    """With a 10-token target and no tolerance every passage stays separate."""
    source = _sample_passages()
    merged = TextMerger().merge_passages(source, chunk_size=10, tolerance=0)

    assert len(merged) == 4
    # Each output chunk must mirror exactly one input passage, in order.
    for position, (chunk, original) in enumerate(zip(merged, source), start=1):
        assert chunk['coordinates'] == str(position)
        assert chunk['text'] == original['text']


def test_merge_passages_big_chunk():
    """With a 20-token target and no tolerance passages are merged pairwise."""
    merged = TextMerger().merge_passages(_sample_passages(), chunk_size=20, tolerance=0)

    assert len(merged) == 2

    first, second = merged
    assert first['coordinates'] == "1;2"
    assert first['text'] == (
        "The quick brown fox jumps over the tree "
        "and went straight into the mouth of a bear."
    )

    assert second['coordinates'] == "3;4"
    assert second['text'] == (
        "The color of the colors is a color with colors "
        "the main colors are not the colorw we show"
    )