Commit

first commit
lfoppiano committed Oct 13, 2023
1 parent b3d12e4 commit e8ebf39
Showing 9 changed files with 560 additions and 1 deletion.
12 changes: 12 additions & 0 deletions .env
@@ -0,0 +1,12 @@
GROBID_URL=https://lfoppiano-grobid.hf.space
PROMPTLAYER_API_KEY=pl_89d0213d12cde8162f75c4da00e74094

HTTP_PROXY=
http_proxy=
HTTPS_PROXY=
https_proxy=
NO_PROXY=
no_proxy=
REQUEST_CA_BUNDLE=
REQUESTS_CA_BUNDLE=
CURL_CA_BUNDLE=
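
These variables point the application at a remote GROBID server and configure optional proxy and CA-bundle settings. A minimal sketch of how the app might read them at startup, assuming the python-dotenv package (the loading code itself is not part of this commit):

import os
from dotenv import load_dotenv  # assumption: python-dotenv is used to read .env

load_dotenv()  # copies key=value pairs from .env into os.environ
grobid_url = os.environ.get("GROBID_URL", "http://localhost:8070")  # fallback is illustrative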
45 changes: 45 additions & 0 deletions .github/workflows/ci-build.yml
@@ -0,0 +1,45 @@
name: Build unstable

on: [push]

concurrency:
  group: unstable
  # cancel-in-progress: true


jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.9
        uses: actions/setup-python@v2
        with:
          python-version: "3.9"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install --upgrade flake8 pytest pycodestyle
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Lint with flake8
        run: |
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      # - name: Test with pytest
      #   run: |
      #     pytest

  docker-build-documentqa:
    needs: [build]

    runs-on: self-hosted

    steps:
      - uses: actions/checkout@v2
      - name: Build the Docker image
        run: docker build . --file Dockerfile.qa --tag lfoppiano/documentqa:develop-latest
      - name: Cleanup older than 24h images and containers
        run: docker system prune --filter "until=24h" --force
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
.idea
.env
.env.docker
**/**/.chroma
grobid_magneto/reverse_qa/.chroma
exploration_llm
resources/db
5 changes: 5 additions & 0 deletions .streamlit/config.toml
@@ -0,0 +1,5 @@
[logger]
level = "info"

[browser]
gatherUsageStats = false
33 changes: 33 additions & 0 deletions Dockerfile
@@ -0,0 +1,33 @@
FROM python:3.9-slim

WORKDIR /app

RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    software-properties-common \
    git \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .

RUN pip3 install -r requirements.txt --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org

COPY grobid_magneto/ grobid_magneto
COPY commons/ commons
COPY resources/nims_proxy.cer .
COPY tiktoken_cache ./tiktoken_cache
COPY grobid_magneto/document_qa/.streamlit ./.streamlit

# extract version
COPY .git ./.git
RUN git rev-parse --short HEAD > revision.txt
RUN rm -rf ./.git

EXPOSE 8501

HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health

ENV PYTHONPATH "${PYTHONPATH}:."

ENTRYPOINT ["streamlit", "run", "document_qa/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
4 changes: 3 additions & 1 deletion README.md
@@ -1 +1,3 @@
# document-qa
# DocumentIQA: Document Insight Question/Answer

A small demo for performing data extraction at the document level using a small context.
251 changes: 251 additions & 0 deletions document_qa_engine.py
@@ -0,0 +1,251 @@
import copy
import os
from pathlib import Path
from typing import Union, Any, Tuple

from grobid_client.grobid_client import GrobidClient
from langchain.chains import create_extraction_chain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.retrievers import MultiQueryRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from tqdm import tqdm

from commons.annotations_utils import GrobidProcessor


class DocumentQAEngine:
    llm = None
    qa_chain_type = None
    embedding_function = None
    embeddings_dict = {}
    embeddings_map_from_md5 = {}
    embeddings_map_to_md5 = {}

    def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
        self.embedding_function = embedding_function
        self.llm = llm
        self.chain = load_qa_chain(llm, chain_type=qa_chain_type)

        if embeddings_root_path is not None:
            self.embeddings_root_path = embeddings_root_path
            if not os.path.exists(embeddings_root_path):
                os.makedirs(embeddings_root_path)
            else:
                self.load_embeddings(self.embeddings_root_path)

        if grobid_url:
            self.grobid_url = grobid_url
            grobid_client = GrobidClient(
                grobid_server=self.grobid_url,
                batch_size=1000,
                coordinates=["p"],
                sleep_time=5,
                timeout=60,
                check_server=True
            )
            self.grobid_processor = GrobidProcessor(grobid_client)

    def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
        """
        Load the embeddings, assuming they are all persisted under a single root directory,
        with one data store per document in its own subdirectory.
        """

        embeddings_directories = [f for f in os.scandir(embeddings_root_path) if f.is_dir()]

        if len(embeddings_directories) == 0:
            print("No available embeddings")
            return

        for embedding_document_dir in embeddings_directories:
            self.embeddings_dict[embedding_document_dir.name] = Chroma(persist_directory=embedding_document_dir.path,
                                                                       embedding_function=self.embedding_function)

            # recover the original filename from the ".storage_filename" marker file
            filename_list = list(Path(embedding_document_dir).glob('*.storage_filename'))
            if filename_list:
                filename = filename_list[0].name.replace(".storage_filename", "")
                self.embeddings_map_from_md5[embedding_document_dir.name] = filename
                self.embeddings_map_to_md5[filename] = embedding_document_dir.name

        print("Embeddings loaded:", len(self.embeddings_dict.keys()))

    def get_loaded_embeddings_ids(self):
        return list(self.embeddings_dict.keys())

    def get_md5_from_filename(self, filename):
        return self.embeddings_map_to_md5[filename]

    def get_filename_from_md5(self, md5):
        return self.embeddings_map_from_md5[md5]

    def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
                       verbose=False) -> Tuple[Any, str]:
        # self.load_embeddings(self.embeddings_root_path)

        if verbose:
            print(query)

        response = self._run_query(doc_id, query, context_size=context_size)
        response = response['output_text'] if 'output_text' in response else response

        if verbose:
            print(doc_id, "->", response)

        if output_parser:
            try:
                return self._parse_json(response, output_parser), response
            except Exception as oe:
                print("Failed to parse the response", oe)
                return None, response
        elif extraction_schema:
            try:
                chain = create_extraction_chain(extraction_schema, self.llm)
                parsed = chain.run(response)
                return parsed, response
            except Exception as oe:
                print("Failed to parse the response", oe)
                return None, response
        else:
            return None, response

    def query_storage(self, query: str, doc_id, context_size=4):
        documents = self._get_context(doc_id, query, context_size)

        context_as_text = [doc.page_content for doc in documents]
        return context_as_text

    def _parse_json(self, response, output_parser):
        system_message = "You are a useful assistant expert in materials science, physics, and chemistry " \
                         "that can process text and transform it to JSON."
        human_message = """Transform the text between three double quotes into JSON.\n\n\n\n
        {format_instructions}\n\nText: \"\"\"{text}\"\"\""""

        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)

        prompt_template = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

        results = self.llm(
            prompt_template.format_prompt(
                text=response,
                format_instructions=output_parser.get_format_instructions()
            ).to_messages()
        )
        parsed_output = output_parser.parse(results.content)

        return parsed_output

    def _run_query(self, doc_id, query, context_size=4):
        relevant_documents = self._get_context(doc_id, query, context_size)
        return self.chain.run(input_documents=relevant_documents, question=query)
        # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)

    def _get_context(self, doc_id, query, context_size=4):
        db = self.embeddings_dict[doc_id]
        retriever = db.as_retriever(search_kwargs={"k": context_size})
        relevant_documents = retriever.get_relevant_documents(query)
        return relevant_documents

    def get_all_context_by_document(self, doc_id):
        db = self.embeddings_dict[doc_id]
        docs = db.get()
        return docs['documents']

    def _get_context_multiquery(self, doc_id, query, context_size=4):
        db = self.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
        multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
        relevant_documents = multi_query_retriever.get_relevant_documents(query)
        return relevant_documents

    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
        if verbose:
            print("File", pdf_file_path)
        filename = Path(pdf_file_path).stem
        structure = self.grobid_processor.process_structure(pdf_file_path)

        biblio = structure['biblio']
        biblio['filename'] = filename.replace(" ", "_")

        if verbose:
            print("Generating embeddings for filename:", filename)

        texts = []
        metadatas = []
        ids = []
        if chunk_size < 0:
            # no chunking: one text per GROBID passage, carrying section metadata
            for passage in structure['passages']:
                biblio_copy = copy.copy(biblio)
                if len(str.strip(passage['text'])) > 0:
                    texts.append(passage['text'])

                    biblio_copy['type'] = passage['type']
                    biblio_copy['section'] = passage['section']
                    biblio_copy['subSection'] = passage['subSection']
                    metadatas.append(biblio_copy)

                    ids.append(passage['passage_id'])
        else:
            # chunking: concatenate all passages, then split by token count
            document_text = " ".join([passage['text'] for passage in structure['passages']])
            # text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
            text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=chunk_size,
                chunk_overlap=int(chunk_size * perc_overlap)
            )
            texts = text_splitter.split_text(document_text)
            metadatas = [biblio for _ in range(len(texts))]
            ids = list(range(len(texts)))

        return texts, metadatas, ids

    def create_memory_embeddings(self, pdf_path, doc_id=None):
        texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=500, perc_overlap=0.1)
        if doc_id:
            doc_hash = doc_id
        else:
            doc_hash = metadata[0]['hash']

        self.embeddings_dict[doc_hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata)
        self.embeddings_root_path = None

        return doc_hash

    def create_embeddings(self, pdfs_dir_path: Path):
        input_files = []
        for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
            for file_ in files:
                if not file_.lower().endswith(".pdf"):
                    continue
                input_files.append(os.path.join(root, file_))

        for input_file in tqdm(input_files, total=len(input_files), unit='document',
                               desc="Grobid + embeddings processing"):
            md5 = self.calculate_md5(input_file)
            data_path = os.path.join(self.embeddings_root_path, md5)

            if os.path.exists(data_path):
                print(data_path, "exists. Skipping it.")
                continue

            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
            filename = metadata[0]['filename']

            vector_db_document = Chroma.from_texts(texts,
                                                   metadatas=metadata,
                                                   embedding=self.embedding_function,
                                                   persist_directory=data_path)
            vector_db_document.persist()

            # empty marker file whose name records the original filename for this MD5 store
            with open(os.path.join(data_path, filename + ".storage_filename"), 'w') as fo:
                fo.write("")

    @staticmethod
    def calculate_md5(input_file: Union[Path, str]):
        import hashlib
        md5_hash = hashlib.md5()
        with open(input_file, 'rb') as fi:
            md5_hash.update(fi.read())
        return md5_hash.hexdigest().upper()
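
Putting the pieces together, a minimal usage sketch for the engine above, assuming the LangChain OpenAI wrappers of this era (ChatOpenAI, OpenAIEmbeddings) and a reachable GROBID server; the model name and PDF path are placeholders:

import os
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings

# hypothetical wiring; any LangChain-compatible llm/embeddings pair should work
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
embeddings = OpenAIEmbeddings()

engine = DocumentQAEngine(llm, embeddings, grobid_url=os.environ["GROBID_URL"])

# parse the PDF with GROBID, chunk it, and build an in-memory Chroma store
doc_id = engine.create_memory_embeddings("paper.pdf")

# retrieve the most relevant chunks and answer against them
_, answer = engine.query_document("What is the sample composition?", doc_id, context_size=4)
print(answer)

Note the two storage modes: create_memory_embeddings keeps the vector store in memory only, while create_embeddings persists one Chroma store per document (keyed by MD5) under embeddings_root_path for later reloading via load_embeddings.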