Skip to content

Commit

Permalink
add merger that preserve the coordinates and aggregate them meaningfully
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Jan 17, 2024
1 parent aeb450e commit c07b97b
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 26 deletions.
122 changes: 96 additions & 26 deletions document_qa/document_qa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,89 @@
from pathlib import Path
from typing import Union, Any

from document_qa.grobid_processors import GrobidProcessor
import tiktoken
from grobid_client.grobid_client import GrobidClient
from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
from langchain.chains import create_extraction_chain
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
map_rerank_prompt
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.retrievers import MultiQueryRetriever
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from tqdm import tqdm

from document_qa.grobid_processors import GrobidProcessor


class TextMerger:
def __init__(self, model_name=None, encoding_name="gpt2"):
if model_name is not None:
self.enc = tiktoken.encoding_for_model(model_name)
else:
self.enc = tiktoken.get_encoding(encoding_name)

def encode(self, text, allowed_special=set(), disallowed_special="all"):
return self.enc.encode(
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)

def merge_passages(self, passages, chunk_size, tolerance=0.2):
new_passages = []
new_coordinates = []
current_texts = []
current_coordinates = []
for idx, passage in enumerate(passages):
text = passage['text']
coordinates = passage['coordinates']
current_texts.append(text)
current_coordinates.append(coordinates)

accumulated_text = " ".join(current_texts)

encoded_accumulated_text = self.encode(accumulated_text)

if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance:
if len(current_texts) > 1:
new_passages.append(current_texts[:-1])
new_coordinates.append(current_coordinates[:-1])
current_texts = [current_texts[-1]]
current_coordinates = [current_coordinates[-1]]
else:
new_passages.append(current_texts)
new_coordinates.append(current_coordinates)
current_texts = []
current_coordinates = []

elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance:
new_passages.append(current_texts)
new_coordinates.append(current_coordinates)
current_texts = []
current_coordinates = []
else:
print("bao")

if len(current_texts) > 0:
new_passages.append(current_texts)
new_coordinates.append(current_coordinates)

new_passages_struct = []
for i, passages in enumerate(new_passages):
text = " ".join(passages)
coordinates = ";".join(new_coordinates[i])

new_passages_struct.append(
{
"text": text,
"coordinates": coordinates,
"type": "aggregated chunks",
"section": "mixed",
"subSection": "mixed"
}
)

return new_passages_struct


class DocumentQAEngine:
Expand Down Expand Up @@ -44,6 +115,7 @@ def __init__(self,
self.llm = llm
self.memory = memory
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
self.text_merger = TextMerger()

if embeddings_root_path is not None:
self.embeddings_root_path = embeddings_root_path
Expand Down Expand Up @@ -157,7 +229,9 @@ def _parse_json(self, response, output_parser):

def _run_query(self, doc_id, query, context_size=4):
relevant_documents = self._get_context(doc_id, query, context_size)
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] for doc in relevant_documents] #filter(lambda d: d['type'] == "sentence", relevant_documents)]
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
for doc in
relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
response = self.chain.run(input_documents=relevant_documents,
question=query)

Expand Down Expand Up @@ -196,7 +270,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
if verbose:
print("File", pdf_file_path)
filename = Path(pdf_file_path).stem
coordinates = True if chunk_size == -1 else False
coordinates = True # if chunk_size == -1 else False
structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)

biblio = structure['biblio']
Expand All @@ -209,29 +283,25 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
metadatas = []
ids = []

if chunk_size < 0:
for passage in structure['passages']:
biblio_copy = copy.copy(biblio)
if len(str.strip(passage['text'])) > 0:
texts.append(passage['text'])
if chunk_size > 0:
new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size)
else:
new_passages = structure['passages']

biblio_copy['type'] = passage['type']
biblio_copy['section'] = passage['section']
biblio_copy['subSection'] = passage['subSection']
biblio_copy['coordinates'] = passage['coordinates']
metadatas.append(biblio_copy)
for passage in new_passages:
biblio_copy = copy.copy(biblio)
if len(str.strip(passage['text'])) > 0:
texts.append(passage['text'])

ids.append(passage['passage_id'])
else:
document_text = " ".join([passage['text'] for passage in structure['passages']])
# text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=chunk_size,
chunk_overlap=chunk_size * perc_overlap
)
texts = text_splitter.split_text(document_text)
metadatas = [biblio for _ in range(len(texts))]
ids = [id for id, t in enumerate(texts)]
biblio_copy['type'] = passage['type']
biblio_copy['section'] = passage['section']
biblio_copy['subSection'] = passage['subSection']
biblio_copy['coordinates'] = passage['coordinates']
metadatas.append(biblio_copy)

# ids.append(passage['passage_id'])

ids = [id for id, t in enumerate(new_passages)]

return texts, metadatas, ids

Expand Down
71 changes: 71 additions & 0 deletions tests/test_document_qa_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from document_qa.document_qa_engine import TextMerger


def test_merge_passages_small_chunk():
merger = TextMerger()

passages = [
{
'text': "The quick brown fox jumps over the tree",
'coordinates': '1'
},
{
'text': "and went straight into the mouth of a bear.",
'coordinates': '2'
},
{
'text': "The color of the colors is a color with colors",
'coordinates': '3'
},
{
'text': "the main colors are not the colorw we show",
'coordinates': '4'
}
]
new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0)

assert len(new_passages) == 4
assert new_passages[0]['coordinates'] == "1"
assert new_passages[0]['text'] == "The quick brown fox jumps over the tree"

assert new_passages[1]['coordinates'] == "2"
assert new_passages[1]['text'] == "and went straight into the mouth of a bear."

assert new_passages[2]['coordinates'] == "3"
assert new_passages[2]['text'] == "The color of the colors is a color with colors"

assert new_passages[3]['coordinates'] == "4"
assert new_passages[3]['text'] == "the main colors are not the colorw we show"


def test_merge_passages_big_chunk():
merger = TextMerger()

passages = [
{
'text': "The quick brown fox jumps over the tree",
'coordinates': '1'
},
{
'text': "and went straight into the mouth of a bear.",
'coordinates': '2'
},
{
'text': "The color of the colors is a color with colors",
'coordinates': '3'
},
{
'text': "the main colors are not the colorw we show",
'coordinates': '4'
}
]
new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0)

assert len(new_passages) == 2
assert new_passages[0]['coordinates'] == "1;2"
assert new_passages[0][
'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear."

assert new_passages[1]['coordinates'] == "3;4"
assert new_passages[1][
'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show"

0 comments on commit c07b97b

Please sign in to comment.