Skip to content

Commit

Permalink
Added backend prompt/response functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
neal-logan committed Nov 18, 2024
1 parent 0685205 commit 159ebf1
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 64 deletions.
208 changes: 167 additions & 41 deletions MinuteMate/back/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from weaviate.classes.query import Rerank, MetadataQuery

import os

Check failure on line 4 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:4:8: F401 `os` imported but unused

Check failure on line 4 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:4:8: F401 `os` imported but unused

from rake_nltk import Rake

import weaviate
from weaviate.classes.init import Auth
from weaviate.classes.init import AdditionalConfig, Timeout
from rake_nltk import Rake
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from weaviate.classes.init import AdditionalConfig

Check failure on line 10 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:10:35: F401 `weaviate.classes.init.AdditionalConfig` imported but unused

Check failure on line 10 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:10:35: F401 `weaviate.classes.init.AdditionalConfig` imported but unused
from weaviate.classes.init import Timeout

Check failure on line 11 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:11:35: F401 `weaviate.classes.init.Timeout` imported but unused

Check failure on line 11 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:11:35: F401 `weaviate.classes.init.Timeout` imported but unused
from weaviate.classes.query import Rerank

Check failure on line 12 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:12:36: F401 `weaviate.classes.query.Rerank` imported but unused

Check failure on line 12 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

MinuteMate/back/main.py:12:36: F401 `weaviate.classes.query.Rerank` imported but unused
from weaviate.classes.query import MetadataQuery

import openai
from openai import OpenAI

weaviate_url = os.environ["WEAVIATE_URL"]
weaviate_api_key = os.environ["WEAVIATE_API_KEY"]
client = weaviate.connect_to_weaviate_cloud(
cluster_url=weaviate_url,
auth_credentials=Auth.api_key(weaviate_api_key),
)

# Initialize the FastAPI app
app = FastAPI(
Expand All @@ -28,46 +26,174 @@

# Define the request schema
class PromptRequest(BaseModel):
prompt: str
user_prompt_text: str

# Define the response schema
class PromptResponse(BaseModel):
response: str
generated_response: str
error_code : int

# Takes a prompt from the front end, processes the prompt
# using NLP tools, embedding services, and generative services
# and finally returns the prompt response
def process_prompt(prompt_request: PromptRequest) -> PromptResponse:

### 0 - ENVIRONMENT AND CONFIGURATION ###

# Update environment variables
# not sure if this works
# TODO test and/or look for alternatives
import os

Check failure on line 46 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F811)

MinuteMate/back/main.py:46:12: F811 Redefinition of unused `os` from line 4

Check failure on line 46 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F811)

MinuteMate/back/main.py:46:12: F811 Redefinition of unused `os` from line 4

# Set API keys, endpoint URLs, model versions, and configurations
# Embedding and Generative Models
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL')
OPENAI_EMBEDDING_URL = os.getenv('OPENAI_EMBEDDING_URL')

Check failure on line 52 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

MinuteMate/back/main.py:52:5: F841 Local variable `OPENAI_EMBEDDING_URL` is assigned to but never used

Check failure on line 52 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

MinuteMate/back/main.py:52:5: F841 Local variable `OPENAI_EMBEDDING_URL` is assigned to but never used
OPENAI_GENERATION_URL = os.getenv('OPENAI_GENERATION_URL')

Check failure on line 53 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

MinuteMate/back/main.py:53:5: F841 Local variable `OPENAI_GENERATION_URL` is assigned to but never used

Check failure on line 53 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

MinuteMate/back/main.py:53:5: F841 Local variable `OPENAI_GENERATION_URL` is assigned to but never used
EMBEDDING_MODEL = 'text-embedding-3-small'
ENCODING_FORMAT = 'float'
RESPONDING_GENERATIVE_MODEL = 'gpt-4o'

Check failure on line 56 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

MinuteMate/back/main.py:56:5: F841 Local variable `RESPONDING_GENERATIVE_MODEL` is assigned to but never used

Check failure on line 56 in MinuteMate/back/main.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F841)

MinuteMate/back/main.py:56:5: F841 Local variable `RESPONDING_GENERATIVE_MODEL` is assigned to but never used
# TRUSTSAFETY_GENERATIVE_MODEL = llama on modal, probably, but can't be too

# API key, endpoint URL, and target collection(s)
# for Weaviate vector database
WEAVIATE_URL = os.environ['WEAVIATE_URL']
WEAVIATE_API_KEY = os.environ['WEAVIATE_API_KEY']
WEAVIATE_TARGET_COLLECTION = 'MeetingDocument'
# WEAVIATE_TARGET_COLLECTION = "VERBA_Embedding_text_embedding_3_small"

# Your Python processing logic
def process_prompt(prompt: str) -> str:
rake = Rake()
rake.extract_keywords_from_text(prompt)
return rake.get_ranked_phrases()[:3]


### 1- INITIAL TRUST AND SAFETY CHECK ###
# TODO add initial trust & safety check here
# If trust and safety check fails, return the error immediately



### 2- INFORMATION RETRIEVAL ###

# Set RAG search type
SEARCH_TYPE = 'keyword'
# SEARCH_TYPE = 'vector'
# SEARCH_TYPE = 'hybrid'

# Establish connection with Weaviate server
# https://weaviate.io/developers/weaviate
weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=WEAVIATE_URL,
auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
)
# Look up the appropriate Weviate database collection
db_collection = weaviate_client.collections.get(WEAVIATE_TARGET_COLLECTION)
db_response = None

# Extract keywords and query database
# TODO - finish and test
if(SEARCH_TYPE == 'keyword'):

rake = Rake()
rake.extract_keywords_from_text(prompt_request.user_prompt_text)
keywords = rake.get_ranked_phrases()[:3]
db_response = db_collection.query.bm25(
query=",".join(keywords),
limit=5,
# rerank=Rerank(
# prop="content",
# query="meeting"
# ),
# return_metadata=MetadataQuery(score=True)
)

# Vectorize the prompt and query the database
# TODO - test
elif(SEARCH_TYPE == 'vector'):


# Set API Key. Not necessary if you have an
# OPENAI_API_KEY variable in your environment
openai.api_key = OPENAI_API_KEY
embedding_client = OpenAI()

# Vector-embed the prompt
embedding_response = embedding_client.embeddings.create(
model = EMBEDDING_MODEL,
input = prompt_request.user_prompt_text,
encoding_format = ENCODING_FORMAT
)

# Extract the vector embeddings list[float] from the embedding response
query_vector = embedding_response.data[0].embedding

# Send vector query to database and get response
db_response = db_collection.query.near_vector(
near_vector=query_vector,
limit=10,
return_metadata=MetadataQuery(distance=True)
)

#TODO support this
#elif(SEARCH_TYPE == 'hybrid'):


else:
#No RAG search
db_response = None

# Extract items from database response
# and aggregate into a single string
db_response_text = ""
for item in db_response.objects:
segment = '\n<ContextSegment' + str(int(item.properties.get('chunk_id'))) + '>\n'
db_response_text += segment
db_response_text += item.properties.get('content')


### 3 - RESPONSE GENERATION ###

# Generate response to user with OpenAI generative model
# https://platform.openai.com/docs/api-reference/chat/create
openai.api_key = OPENAI_API_KEY
generation_client = OpenAI()
generated_response_text = generation_client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "system",
"content": f"You are a helpful assistant who uses this context if appropriate: {db_response_text}"
},
{
"role": "user",
"content": prompt_request.user_prompt_text
}
]
)

### 4 - FINAL TRUST AND SAFETY CHECK ###
# TODO add final trust & safety check here
# If trust and safety check fails, return an error


### 5 - BUILD & RETURN RESPONSE OBJECT ###
# Return chat response to API layer
# to be passed along to frontend
prompt_response = PromptResponse()
prompt_response.generated_response = generated_response_text
return prompt_request





# API endpoint
@app.post("/process-prompt", response_model=PromptResponse)
async def process_prompt_endpoint(request: PromptRequest):
async def process_prompt_endpoint(prompt_request: PromptRequest):
"""
Process the prompt and return the response
"""
try:
result = process_prompt(request.prompt)
return PromptResponse(result=result)
prompt_response = process_prompt(prompt_request)
return PromptResponse(result=prompt_response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))



# This is the code to get the collections based on top 3 keywords we are fetching from the RAKE code.
# You can add this code block below wherever you are configuring your API
collection = client.collections.get("MeetingDocument")
response = collection.query.bm25(
query=",".join(keywords),
limit=5,
# rerank=Rerank(
# prop="content",
# query="meeting"
# ),
# return_metadata=MetadataQuery(score=True)
)

for o in response.objects:
print(o.properties)
# print(o.metadata.rerank_score)

Empty file.
11 changes: 10 additions & 1 deletion MinuteMate/back/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,13 @@ fastapi[standard]
pydantic

# Necessary for web stuff
uvicorn
uvicorn

# vector database
weaviate-client==4.7.1

# environmental variables such as API keys and endpoint URLs
python-dotenv==1.0.0

# for running queries
openai==1.54.3
47 changes: 25 additions & 22 deletions dev_notebooks/rag_prompt_dev.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -200,7 +200,11 @@
"ENCODING_FORMAT = \"float\"\n",
"\n",
"# Database\n",
"COLLECTION_NAME = \"VERBA_Embedding_text_embedding_3_small\""
"COLLECTION_NAME = \"VERBA_Embedding_text_embedding_3_small\"\n",
"\n",
"RESPONDING_GENERATIVE_MODEL = \"gpt-4o\"\n",
"\n",
"TRUSTSAFETY_GENERATIVE_MODEL = \"gpt-4o\""
]
},
{
Expand All @@ -221,13 +225,13 @@
"from weaviate.classes.init import Auth, AdditionalConfig, Timeout\n",
"import weaviate\n",
"\n",
"client_db = weaviate.connect_to_weaviate_cloud(\n",
"weaviate_client = weaviate.connect_to_weaviate_cloud(\n",
" cluster_url = WEAVIATE_URL,\n",
" auth_credentials = Auth.api_key(WEAVIATE_API_KEY),\n",
" additional_config=AdditionalConfig(timeout=Timeout(init=30, query=60, insert=120)) # Values in seconds\n",
")\n",
"\n",
"print(client_db.is_ready())"
"print(weaviate_client.is_ready())"
]
},
{
Expand All @@ -236,7 +240,7 @@
"metadata": {},
"outputs": [],
"source": [
"# for collection in client_db.collections.list_all():\n",
"# for collection in weaviate_client.collections.list_all():\n",
"# print(collection)"
]
},
Expand All @@ -246,7 +250,7 @@
"metadata": {},
"outputs": [],
"source": [
"# client_db.close()"
"# weaviate_client.close()"
]
},
{
Expand Down Expand Up @@ -286,8 +290,7 @@
"# Set API Key. Not necessary if you have an \n",
"# OPENAI_API_KEY variable in your environment\n",
"openai.api_key = OPENAI_API_KEY \n",
"\n",
"client_embedding = OpenAI()"
"embedding_client = OpenAI()"
]
},
{
Expand All @@ -296,9 +299,9 @@
"metadata": {},
"outputs": [],
"source": [
"# print(type(client_embedding))\n",
"# print(type(embedding_client))\n",
"\n",
"# for item in client_embedding.models.list():\n",
"# for item in embedding_client.models.list():\n",
"# print(item)"
]
},
Expand Down Expand Up @@ -343,24 +346,24 @@
"source": [
"from weaviate.classes.query import MetadataQuery\n",
"\n",
"query_text = \"I'd like to know about issues with plumbing in or around 2024\"\n",
"user_prompt = \"I'd like to know about issues with plumbing in or around 2024\"\n",
"\n",
"# Vectorize the query\n",
"response_embedding = client_embedding.embeddings.create(\n",
"embedding_response = embedding_client.embeddings.create(\n",
" model = EMBEDDING_MODEL,\n",
" input = query_text,\n",
" input = user_prompt,\n",
" encoding_format = ENCODING_FORMAT\n",
")\n",
"\n",
"# Extract the verctor embeddings list[float] from the embedding response\n",
"query_vector = openai_extract_vector(response_embedding) \n",
"# Extract the vector embeddings list[float] from the embedding response\n",
"vectorized_query = openai_extract_vector(embedding_response) \n",
"\n",
"# Look up the appropriate Weviate database collection - name based on embedding model used\n",
"collection = client_db.collections.get('VERBA_Embedding_text_embedding_3_small')\n",
"db_collection = weaviate_client.collections.get('VERBA_Embedding_text_embedding_3_small')\n",
"\n",
"# Send vector query to database and get response\n",
"db_response = collection.query.near_vector(\n",
" near_vector=query_vector,\n",
"db_response = db_collection.query.near_vector(\n",
" near_vector=vectorized_query,\n",
" limit=10,\n",
" return_metadata=MetadataQuery(distance=True)\n",
")\n",
Expand All @@ -374,11 +377,11 @@
" db_response_content += item.properties.get('content')\n",
"\n",
"# # Print results\n",
"# for item in response_db.objects:\n",
"# for item in db_response.objects:\n",
"# print(item.properties)\n",
"# print(item.metadata.distance)\n",
"\n",
"# print(response_content)"
"# print(db_response_content)"
]
},
{
Expand All @@ -392,15 +395,15 @@
"\n",
"\n",
"generation_response = generation_client.chat.completions.create(\n",
" model=\"gpt-4o\",\n",
" model=RESPONDING_GENERATIVE_MODEL,\n",
" messages=[\n",
" {\n",
" \"role\": \"system\", \n",
" \"content\": f\"You are a helpful assistant who uses this context if appropriate: {db_response_content}\"\n",
" },\n",
" {\n",
" \"role\": \"user\", \n",
" \"content\": query_text \n",
" \"content\": user_prompt \n",
" }\n",
" ]\n",
")\n",
Expand Down

0 comments on commit 159ebf1

Please sign in to comment.