-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
270 lines (213 loc) · 8.64 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
from fastapi import FastAPI, HTTPException,UploadFile,File
import os
from dotenv import load_dotenv
import google.generativeai as genai
import io
import fitz
from PIL import Image
from io import BytesIO
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
# import faiss
# import numpy as np
# from transformers import AutoModel, AutoTokenizer
# Load environment variables from a local .env file (expects GOOGLE_API_KEY).
load_dotenv()
# FastAPI instance
app = FastAPI()
# Origins allowed by the CORS middleware below.
# NOTE(review): "http://localhost:5173/" has a trailing slash — browser Origin
# headers never include one, so this entry can never match; confirm and drop it.
# NOTE(review): "*" combined with allow_credentials=True is an invalid CORS
# combination per the spec (browsers reject wildcard origins on credentialed
# requests) — confirm whether credentials are actually needed.
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://localhost:5173/",
    "*"
]
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,       # List of allowed origins
    allow_credentials=True,      # Allow cookies and credentials
    allow_methods=["*"],         # Allow all HTTP methods
    allow_headers=["*"],         # Allow all headers
)
# Set the API key for google.generativeai explicitly; fail fast at import
# time if the key is missing rather than erroring on the first request.
GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    raise RuntimeError("Google API Key is not set in environment variables.")
# Shared Gemini model instance used by all request handlers.
model = genai.GenerativeModel('gemini-1.5-flash')
# In-memory store for extracted document text, keyed by uploaded filename.
# NOTE(review): both stores are process-local and unbounded — lost on restart
# and shared across all clients; chat_history in particular grows forever.
uploaded_docs = {}
chat_history = []
def extract_text_from_pdf(file_io):
    """Extract and concatenate the text of every page of a PDF.

    Args:
        file_io: A bytes object or binary file-like stream containing the PDF
            (passed to ``fitz.open(stream=...)``).

    Returns:
        str: The text of all pages, in page order.

    Raises:
        Whatever ``fitz`` raises for a corrupt or non-PDF stream.
    """
    pdf_document = fitz.open(stream=file_io, filetype="pdf")
    try:
        # Collect page texts and join once — avoids quadratic "+=" growth
        # on large documents.
        pages = [page.get_text() for page in pdf_document]
    finally:
        # Close the document even if text extraction raises (the original
        # leaked the handle on error).
        pdf_document.close()
    return "".join(pages)
def generate_gemini_response(prompt):
    """Answer a free-form prompt with Gemini, using the running chat history.

    The raw prompt is appended to the module-level ``chat_history``, the whole
    history is sent as context, and the model's reply is appended as well so
    subsequent calls see the full conversation.

    Args:
        prompt: The user's question or instruction.

    Returns:
        str: The model's reply with surrounding whitespace stripped.
    """
    chat_history.append(prompt)
    context = "\n".join(chat_history)
    formatted_prompt = (
        f"{context}\n\n"
        "Please respond to the following prompt in Markdown format, using bullet points, lists, tables, or headers as needed. "
        "Ensure that the response is concise, relevant, and structured for easy readability.\n\n"
        f"**Prompt:** {prompt}\n\n"
        "**Instructions:**\n"
        "- Provide direct, well-structured answers to the question.\n"
        "- Use Markdown elements like headers (`#`), bullet points (`-`), tables, or numbered lists to organize information.\n"
        "- Focus on accuracy and clarity, answering based on the context of previous messages.\n\n"
        "Return only the Markdown response text."
    )
    response = model.generate_content(
        formatted_prompt,
        generation_config=genai.types.GenerationConfig(
            candidate_count=1,
            stop_sequences=[],
            temperature=0.5,
            # max_output_tokens=1000,
        )
    )
    # Strip once and reuse — the original computed response.text.strip() twice.
    answer = response.text.strip()
    chat_history.append(answer)
    return answer
def generate_document_response(prompt, document_text):
    """Answer a question about an uploaded document and suggest follow-ups.

    Builds a document-grounded prompt, sends it (with the accumulated
    ``chat_history``) to Gemini, then asks the model for three follow-up
    questions related to the user's original question.

    Args:
        prompt: The user's question about the document.
        document_text: Full extracted text of the document.

    Returns:
        dict: ``{"response": <answer text>, "suggested_questions": [<lines>]}``.
    """
    # Keep the composite prompt in its own variable. The original rebound
    # ``prompt`` here, so the follow-up request below embedded the entire
    # composite prompt (document text included) instead of the user's question.
    qa_prompt = (
        f"## Document Text\n"
        f"{document_text}\n\n"
        f"## Question\n"
        f"{prompt}\n\n"
        f"---\n\n"
        f"**Response:**\n"
        f"Based on the document, provide an answer to the above question as accurately and concisely as possible. "
        f"Use only information that directly addresses the question without extra details."
    )
    chat_history.append(qa_prompt)
    context = "\n".join(chat_history)
    # Combine the document content with the question for context
    response = model.generate_content(
        context,
        generation_config=genai.types.GenerationConfig(
            candidate_count=1,
            stop_sequences=[],
            # max_output_tokens=10000,
            temperature=0.5,
        )
    )
    chat_history.append(response.text)
    suggestive_prompt = (
        f"Based on the document text below:\n\n{document_text}\n\n"
        f"Generate the three most relevant follow-up questions for the following question: '{prompt}'."
        f"These questions should help the user explore important details, clarify key points, or deepen understanding."
        f"Only output the questions in the following format:\n\n"
        f"1. [First question]\n"
        f"2. [Second question]\n"
        f"3. [Third question]\n\n"
        f"Do not include any additional text or explanations."
    )
    suggestions = model.generate_content(suggestive_prompt, generation_config=genai.types.GenerationConfig(
        candidate_count=1,
        stop_sequences=[],
        # max_output_tokens=200,
        temperature=0.5,
    ))
    # Guard against an empty suggestions payload.
    suggested_questions = suggestions.text.split("\n") if suggestions.text else []
    return {
        "response": response.text,
        "suggested_questions": suggested_questions
    }
def upload_file_to_gemini(sourceFile):
    """Upload a PDF to the Gemini file API and return its identifier.

    Prefers the response's ``id`` attribute, falling back to ``name``.

    Raises:
        ValueError: If the upload response exposes neither attribute.
    """
    response = genai.upload_file(sourceFile, mime_type="application/pdf")
    # Probe the candidate identifier attributes in priority order.
    for attribute in ("id", "name"):
        if hasattr(response, attribute):
            return getattr(response, attribute)
    raise ValueError("File upload response does not contain a file ID or unique identifier.")
def multimodal_search(image_data, prompt):
    """Answer a question about an image using Gemini's multimodal input.

    Args:
        image_data: Raw image bytes (any format PIL can open).
        prompt: The question to answer about the image.

    Returns:
        str: The model's Markdown-formatted answer.
    """
    picture = Image.open(BytesIO(image_data))
    analysis_prompt = (
        f"Analyze the provided image in detail. Identify key visual elements, context, and relevant patterns "
        f"or information that could answer the question below as accurately as possible:\n\n"
        f"Image Analysis:\n\n"
        f"Based on the image content, please answer the following question:\n\n"
        f"**Question:** {prompt}\n\n"
        f"**Instructions:** Provide a concise, well-structured answer in Markdown format, including "
        f"bullet points, lists, or tables if they help to present the information clearly. Respond directly "
        f"to the question, focusing on visual evidence from the image."
    )
    result = model.generate_content([analysis_prompt, picture])
    return result.text
class PromptRequest(BaseModel):
    """Request body for /text-search/: a single free-form user prompt."""
    # The user's question/instruction, forwarded verbatim to Gemini.
    prompt: str
class DocumentSearchRequest(BaseModel):
    """Request body for /document-search/: question about an uploaded doc."""
    # Key into the in-memory uploaded_docs store (the uploaded filename).
    file_id: str
    # The question to answer from that document's text.
    question: str
# API ROUTES
@app.post("/text-search/")
async def generate_response(request: PromptRequest):
    """Run a plain-text Gemini chat query and return its Markdown answer."""
    try:
        answer = generate_gemini_response(request.prompt)
    except Exception as exc:
        # Surface any model/API failure as a 500 with the error message.
        raise HTTPException(status_code=500, detail=str(exc))
    return {"response": answer}
@app.post('/upload-source-file/')
async def upload_source_file(sourceFile: UploadFile = File(...)):
    """Accept a PDF upload, extract its text, and cache it in memory.

    The uploaded filename doubles as the document ID returned to the client.
    """
    try:
        raw_bytes = await sourceFile.read()
        # Extract text from the PDF
        document_text = extract_text_from_pdf(io.BytesIO(raw_bytes))
        # Store the document text with a unique ID
        file_id = sourceFile.filename
        uploaded_docs[file_id] = document_text
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"message": "File uploaded successfully", "file_id": file_id}
@app.post("/document-search/")
async def ask_question(request: DocumentSearchRequest):
    """Answer a question against a previously uploaded document.

    Returns the model's answer plus suggested follow-up questions, or a 404
    if the document ID is unknown.
    """
    # Retrieve the document text from in-memory storage. The 404 check is
    # outside the try block: the original raised it inside and the generic
    # `except Exception` re-wrapped the HTTPException as a 500.
    document_text = uploaded_docs.get(request.file_id)
    if not document_text:
        raise HTTPException(status_code=404, detail="Document not found.")
    try:
        # Generate a response based on the document text and question
        response = generate_document_response(request.question, document_text)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get('/list-documents/')
async def list_documents():
    """List the IDs of all cached documents plus their total count."""
    try:
        document_ids = list(uploaded_docs)
        return {"total_documents": len(document_ids), "document_ids": document_ids}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.delete('/delete-document/{file_id}')
async def delete_document(file_id: str):
    """Remove a cached document by ID; 404 if the ID is unknown.

    The 404 check sits outside the try block: the original raised the
    HTTPException inside it, where `except Exception` converted every
    missing-document response into a 500.
    """
    if file_id not in uploaded_docs:
        # Document not found
        raise HTTPException(status_code=404, detail="Document not found.")
    try:
        # Delete the document from the dictionary
        del uploaded_docs[file_id]
        return {"message": f"Document with ID '{file_id}' has been deleted successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post('/multimodal-search/')
async def multimodal_search_endpoint(image: UploadFile = File(...), prompt: str = None):
    """Answer a question about an uploaded image via Gemini's vision input."""
    try:
        raw_image = await image.read()
        answer = multimodal_search(raw_image, prompt)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"response": answer}
@app.get('/')
async def root():
    """Landing route: confirms the service is up."""
    greeting = {"message": "Welcome to the Gemini API!"}
    return greeting