Skip to content

Commit

Permalink
Support for multiple files in RAG virtual assistant
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianboguszewski committed Nov 26, 2024
1 parent 1e80af7 commit 9f5ec19
Showing 1 changed file with 31 additions and 25 deletions.
56 changes: 31 additions & 25 deletions demos/virtual_ai_assistant_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,39 +124,45 @@ def load_chat_models(chat_model_name: str, embedding_model_name: str, personalit
memory=ChatMemoryBuffer.from_defaults())


def load_file(file_path: Path) -> Document:
    """Read a single PDF or plain-text file into a ``Document``.

    The reader is chosen from the path suffix. The returned document carries
    the extracted text plus a ``file_name`` metadata entry holding the file's
    base name.

    Args:
        file_path: Location of the file to read; ``.pdf`` and ``.txt`` are handled.

    Returns:
        A ``Document`` holding the text of the file.

    Raises:
        ValueError: If the suffix is neither ``.pdf`` nor ``.txt``.
    """
    ext = file_path.suffix

    if ext == ".pdf":
        # Using PyMuPDF (fitz) to read PDF content
        with fitz.open(file_path) as pdf:
            # Build the text in one join instead of growing a string page by page;
            # every page's text is still followed by a newline.
            text = "".join(page.get_text("text") + "\n" for page in pdf)
        return Document(text=text, metadata={"file_name": file_path.name})

    if ext == ".txt":
        # Reading text files as usual
        with open(file_path) as f:
            content = f.read()
        return Document(text=content, metadata={"file_name": file_path.name})

    raise ValueError(f"{ext} file is not supported for now")


def load_context(file_path: str) -> None:
def load_files(file_paths: List[str]) -> list[Document]:
    """Read every supported file in ``file_paths`` into a ``Document``.

    ``.pdf`` files are read page by page with PyMuPDF (fitz) and ``.txt`` files
    are read as text. Files with any other suffix are not an error: a warning
    is logged and the file is skipped.

    Args:
        file_paths: Paths of the files to load.

    Returns:
        One ``Document`` per supported file, in input order. Each document has
        a ``file_name`` metadata entry holding the file's base name. The list
        is empty when no path is given or none is supported.
    """
    documents = []
    for file_path in map(Path, file_paths):
        ext = file_path.suffix
        if ext == ".pdf":
            # Using PyMuPDF (fitz) to read PDF content
            with fitz.open(file_path) as pdf:
                # Extract text from each page, each followed by a newline
                text = "".join(page.get_text("text") + "\n" for page in pdf)
            # Remove non-breaking spaces. str.replace returns a new string, so
            # the result has to be rebound or the replacement is silently lost.
            text = text.replace("\xa0", " ")
            documents.append(Document(text=text, metadata={"file_name": file_path.name}))

        elif ext == ".txt":
            # Reading text files as usual
            with open(file_path) as f:
                content = f.read()
            documents.append(Document(text=content, metadata={"file_name": file_path.name}))

        else:
            log.warning(f"{ext} file is not supported for now. Skipping {file_path.name}")

    return documents


def load_context(file_paths: List[str]) -> None:
    """Rebuild the global chat engine, optionally grounded in uploaded files.

    With no files, ``ov_chat_engine`` becomes a plain ``SimpleChatEngine``.
    Otherwise the files are loaded with ``load_files``, indexed into a
    ``VectorStoreIndex`` and ``ov_chat_engine`` becomes a context-mode chat
    engine over that index. Either way the engine gets a fresh memory buffer
    and the system prompt from ``chatbot_config["system_configuration"]``.

    Side effects: rebinds the module-level ``ov_chat_engine`` and, when files
    are given, sets ``Settings.embed_model`` to ``ov_embedding``.

    Args:
        file_paths: Paths of the files to use as additional context; an empty
            or ``None`` value selects the context-free engine.
    """
    global ov_chat_engine

    # from_defaults() is called without an explicit token limit, so the library
    # default applies (3000 tokens according to the original comment - TODO confirm).
    memory = ChatMemoryBuffer.from_defaults()

    if not file_paths:
        ov_chat_engine = SimpleChatEngine.from_defaults(llm=ov_llm, system_prompt=chatbot_config["system_configuration"], memory=memory)
        return

    documents = load_files(file_paths)
    # The embedding model must be configured before the index is built.
    Settings.embed_model = ov_embedding
    index = VectorStoreIndex.from_documents(documents)
    ov_chat_engine = index.as_chat_engine(llm=ov_llm, chat_mode=ChatMode.CONTEXT, system_prompt=chatbot_config["system_configuration"],
                                          memory=memory)

Expand Down Expand Up @@ -201,7 +207,7 @@ def create_UI(initial_message: str, action_name: str) -> gr.Blocks:
gr.Markdown(chatbot_config["instructions"])

with gr.Row():
file_uploader_ui = gr.File(label="Additional context", file_types=[".pdf", ".txt"], scale=1)
file_uploader_ui = gr.Files(label="Additional context", file_types=[".pdf", ".txt"], scale=1)
with gr.Column(scale=4):
chatbot_ui = gr.Chatbot(value=[[None, initial_message]], label="Chatbot")
with gr.Row():
Expand Down

0 comments on commit 9f5ec19

Please sign in to comment.