-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
143 lines (117 loc) · 4.02 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import time
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_text_splitters.character import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from pinecone import Pinecone
from pinecone import ServerlessSpec
from langchain.chains.question_answering import load_qa_chain
from langchain_pinecone import PineconeVectorStore
from langchain_groq import ChatGroq
import warnings
# Silence noisy third-party deprecation/user warnings for this demo app.
warnings.filterwarnings('ignore')
load_dotenv()
# Load environment variables
api_key = os.getenv("PINECONE_API_KEY")
if not api_key:
    raise ValueError("PINECONE_API_KEY is not set in the .env file.")
# Initialize Pinecone
pc = Pinecone(api_key=api_key)
index_name = "budget"  # Pinecone index holding the budget-document embeddings
# FastAPI app
app = FastAPI()
# Initialize LLM
llm = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0  # deterministic output for factual Q&A
)
# "stuff" chain: concatenates all retrieved chunks into one prompt for the LLM
chain = load_qa_chain(llm, chain_type="stuff")
# Load documents
def read_doc(directory):
    """Load every PDF found under *directory* into LangChain documents."""
    return PyPDFDirectoryLoader(directory).load()
# Eagerly load all PDFs dropped into uploads/ at module import time.
doc = read_doc("uploads/")
# Chunk documents
def chunk_data(docs, chunk_size=800, chunk_overlap=80):
    """Split *docs* into overlapping character chunks ready for embedding.

    chunk_overlap keeps neighbouring chunks sharing context so retrieval
    does not cut sentences at hard boundaries.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(docs)
# Chunk the loaded PDFs for embedding.
chunked_docs = chunk_data(doc)
print(f"The length of chunked data is {len(chunked_docs)}")
# Embed documents
# NOTE(review): HuggingFaceEmbeddings with no model argument — assumes the
# default model emits 768-dim vectors; confirm it matches embedding_dim.
embeddings = HuggingFaceEmbeddings()
embedding_dim = 768  # must equal the embedding model's output dimension
existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
print(f"The names of existing indexes are {existing_indexes}")
### Create Pinecone Index
if index_name not in existing_indexes:
    pc.create_index(
        name=index_name,
        dimension=embedding_dim,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    # Wait for the serverless index to become ready BEFORE upserting;
    # previously the wait happened after the upsert, which can fail on a
    # freshly created index.
    while not pc.describe_index(index_name).status["ready"]:
        time.sleep(1)
    # Embed and upsert the chunks into the new index (return value unused;
    # we reconnect uniformly via from_existing_index below).
    PineconeVectorStore.from_documents(
        chunked_docs,
        index_name=index_name,
        embedding=embeddings,
    )
## Connect to Index
# (Dropped a dead `index = pc.Index(index_name)` that was immediately
# overwritten by the assignment below.)
index = PineconeVectorStore.from_existing_index(
    index_name=index_name,
    embedding=embeddings,
)
print(f"The index is {index}")
# API input model
class Query(BaseModel):
    """Request body for the /query endpoint."""
    query: str      # natural-language question about the budget documents
    top_k: int = 2  # desired number of matches (not forwarded by the endpoint currently)
# Function to retrieve matching documents
def retrieve_query(query: str, k: int = 2):
    """Return the top-*k* chunks most similar to *query* from the vector store."""
    return index.similarity_search(query=query, k=k)
# Function to get an answer
def retrieve_answer(query: str):
    """Retrieve relevant chunks for *query* and run the QA chain over them.

    Any failure is surfaced as an HTTP 500 so the FastAPI endpoint can
    return it directly.
    """
    try:
        matches = retrieve_query(query)
        return chain.run(input_documents=matches, question=query)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# FastAPI endpoints
@app.post("/query")
async def query_documents(query: Query):
    """Query the document index and retrieve an answer."""
    result = retrieve_answer(query.query)
    print(result)  # server-side trace of the generated answer
    return {"query": query.query, "answer": result}
def streamlit_ui():
    """Render the Streamlit front-end for the budget Q&A assistant."""
    st.title("Indian Budget Expert")
    st.write("I am your Budget Buddy to help you understand the Budget")
    user_query = st.text_input("Enter your query:")
    # Guard clauses replace the original nested conditionals; behavior is identical.
    if not st.button("Get Answer"):
        return
    if not user_query.strip():
        st.warning("Please enter a query.")
        return
    with st.spinner("Retrieving answer..."):
        try:
            answer = retrieve_answer(user_query)
            st.success(f"Answer: {answer}")
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
# Entry point: launches the Streamlit UI when run directly
# (e.g. `streamlit run app.py`); the FastAPI `app` is served separately.
if __name__ == "__main__":
    streamlit_ui()