From 0685205f3a50fdfc24085e3cce9e0471e0688496 Mon Sep 17 00:00:00 2001 From: Yash Pradhan Date: Sun, 17 Nov 2024 15:21:24 -0500 Subject: [PATCH 1/2] Added RAKE Keyword extraction and query search --- MinuteMate/back/main.py | 45 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/MinuteMate/back/main.py b/MinuteMate/back/main.py index a300dabb..54caf9d1 100644 --- a/MinuteMate/back/main.py +++ b/MinuteMate/back/main.py @@ -1,5 +1,22 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel +from weaviate.classes.query import Rerank, MetadataQuery +import os +import weaviate +from weaviate.classes.init import Auth +from weaviate.classes.init import AdditionalConfig, Timeout +from rake_nltk import Rake +import nltk +nltk.download('stopwords') +nltk.download('punkt') + + +weaviate_url = os.environ["WEAVIATE_URL"] +weaviate_api_key = os.environ["WEAVIATE_API_KEY"] +client = weaviate.connect_to_weaviate_cloud( + cluster_url=weaviate_url, + auth_credentials=Auth.api_key(weaviate_api_key), +) # Initialize the FastAPI app app = FastAPI( @@ -19,11 +36,10 @@ class PromptResponse(BaseModel): # Your Python processing logic def process_prompt(prompt: str) -> str: - response = '' - - #Call whatever code we need to here - - return response + rake = Rake() + rake.extract_keywords_from_text(prompt) + return rake.get_ranked_phrases()[:3] + # API endpoint @app.post("/process-prompt", response_model=PromptResponse) @@ -36,3 +52,22 @@ async def process_prompt_endpoint(request: PromptRequest): return PromptResponse(result=result) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + + + +# This is the code to get the collections based on top 3 keywords we are fetching from the RAKE code. 
+# You can add this code block wherever you configure your API
+collection = client.collections.get("MeetingDocument")
+response = collection.query.bm25(
+    query=",".join(keywords),
+    limit=5,
+    # rerank=Rerank(
+    #     prop="content",
+    #     query="meeting"
+    # ),
+    # return_metadata=MetadataQuery(score=True)
+)
+
+for o in response.objects:
+    print(o.properties)
+    # print(o.metadata.rerank_score)

From 159ebf171b98064caf98ba6d920c36fa406269a5 Mon Sep 17 00:00:00 2001
From: neal logan
Date: Mon, 18 Nov 2024 11:01:10 -0500
Subject: [PATCH 2/2] Added backend prompt/response functionality

---
 MinuteMate/back/main.py             | 208 ++++++++++++++++++++------
 MinuteMate/back/methods/__init__.py |   0
 MinuteMate/back/requirements.txt    |  11 +-
 dev_notebooks/rag_prompt_dev.ipynb  |  47 ++++---
 4 files changed, 202 insertions(+), 64 deletions(-)
 create mode 100644 MinuteMate/back/methods/__init__.py
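
For reference, the keyword step that PATCH 1/2 introduced and this patch carries forward can be exercised on its own. A minimal sketch, assuming rake_nltk and nltk are installed; the sample prompt string is illustrative only:

    import nltk
    from rake_nltk import Rake

    # rake_nltk tokenizes with 'punkt' and filters with 'stopwords',
    # so both NLTK corpora must be available locally
    nltk.download('stopwords', quiet=True)
    nltk.download('punkt', quiet=True)

    rake = Rake()
    rake.extract_keywords_from_text(
        "What did the town council decide about stormwater drainage repairs?"
    )
    keywords = rake.get_ranked_phrases()[:3]  # top 3 phrases, as in main.py
    print(keywords)
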
diff --git a/MinuteMate/back/main.py b/MinuteMate/back/main.py
index 54caf9d1..43bb144a 100644
--- a/MinuteMate/back/main.py
+++ b/MinuteMate/back/main.py
@@ -1,22 +1,25 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from weaviate.classes.query import Rerank, MetadataQuery
+
 import os
+
+from rake_nltk import Rake
+import nltk
+
+# rake_nltk needs the NLTK 'stopwords' and 'punkt' data at runtime
+nltk.download('stopwords', quiet=True)
+nltk.download('punkt', quiet=True)
+
 import weaviate
 from weaviate.classes.init import Auth
-from weaviate.classes.init import AdditionalConfig, Timeout
-from rake_nltk import Rake
-import nltk
-nltk.download('stopwords')
-nltk.download('punkt')
+from weaviate.classes.init import AdditionalConfig
+from weaviate.classes.init import Timeout
+from weaviate.classes.query import Rerank
+from weaviate.classes.query import MetadataQuery
+import openai
+from openai import OpenAI
 
-
-weaviate_url = os.environ["WEAVIATE_URL"]
-weaviate_api_key = os.environ["WEAVIATE_API_KEY"]
-client = weaviate.connect_to_weaviate_cloud(
-    cluster_url=weaviate_url,
-    auth_credentials=Auth.api_key(weaviate_api_key),
-)
 
 # Initialize the FastAPI app
 app = FastAPI(
@@ -28,46 +26,177 @@
 
 # Define the request schema
 class PromptRequest(BaseModel):
-    prompt: str
+    user_prompt_text: str
 
 # Define the response schema
 class PromptResponse(BaseModel):
-    response: str
+    generated_response: str
+    error_code: int
+
+# Takes a prompt from the front end, processes the prompt
+# using NLP tools, embedding services, and generative services
+# and finally returns the prompt response
+def process_prompt(prompt_request: PromptRequest) -> PromptResponse:
+
+    ### 0 - ENVIRONMENT AND CONFIGURATION ###
+
+    # Set API keys, endpoint URLs, model versions, and configurations
+    # Embedding and Generative Models
+    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
+    # OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL')
+    OPENAI_EMBEDDING_URL = os.getenv('OPENAI_EMBEDDING_URL')
+    OPENAI_GENERATION_URL = os.getenv('OPENAI_GENERATION_URL')
+    EMBEDDING_MODEL = 'text-embedding-3-small'
+    ENCODING_FORMAT = 'float'
+    RESPONDING_GENERATIVE_MODEL = 'gpt-4o'
+    # TRUSTSAFETY_GENERATIVE_MODEL: TBD, possibly a Llama model hosted on Modal
+
+    # API key, endpoint URL, and target collection(s)
+    # for Weaviate vector database
+    WEAVIATE_URL = os.environ['WEAVIATE_URL']
+    WEAVIATE_API_KEY = os.environ['WEAVIATE_API_KEY']
+    WEAVIATE_TARGET_COLLECTION = 'MeetingDocument'
+    # WEAVIATE_TARGET_COLLECTION = "VERBA_Embedding_text_embedding_3_small"
 
-# Your Python processing logic
-def process_prompt(prompt: str) -> str:
-    rake = Rake()
-    rake.extract_keywords_from_text(prompt)
-    return rake.get_ranked_phrases()[:3]
+
+
+    ### 1 - INITIAL TRUST AND SAFETY CHECK ###
+    # TODO add initial trust & safety check here
+    # If trust and safety check fails, return the error immediately
+
+
+
+    ### 2 - INFORMATION RETRIEVAL ###
+
+    # Set RAG search type
+    SEARCH_TYPE = 'keyword'
+    # SEARCH_TYPE = 'vector'
+    # SEARCH_TYPE = 'hybrid'
+
+    # Establish connection with Weaviate server
+    # https://weaviate.io/developers/weaviate
+    weaviate_client = weaviate.connect_to_weaviate_cloud(
+        cluster_url=WEAVIATE_URL,
+        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
+    )
+    # Look up the appropriate Weaviate database collection
+    db_collection = weaviate_client.collections.get(WEAVIATE_TARGET_COLLECTION)
+    db_response = None
+
+    # Extract keywords and query database
+    # TODO - finish and test
+    if SEARCH_TYPE == 'keyword':
+
+        rake = Rake()
+        rake.extract_keywords_from_text(prompt_request.user_prompt_text)
+        keywords = rake.get_ranked_phrases()[:3]
+        db_response = db_collection.query.bm25(
+            query=",".join(keywords),
+            limit=5,
+            # rerank=Rerank(
+            #     prop="content",
+            #     query="meeting"
+            # ),
+            # return_metadata=MetadataQuery(score=True)
+        )
+
+    # Vectorize the prompt and query the database
+    # TODO - test
+    elif SEARCH_TYPE == 'vector':
+
+
+        # Set API Key. Not necessary if you have an
+        # OPENAI_API_KEY variable in your environment
+        openai.api_key = OPENAI_API_KEY
+        embedding_client = OpenAI()
+
+        # Vector-embed the prompt
+        embedding_response = embedding_client.embeddings.create(
+            model = EMBEDDING_MODEL,
+            input = prompt_request.user_prompt_text,
+            encoding_format = ENCODING_FORMAT
+        )
+
+        # Extract the vector embeddings list[float] from the embedding response
+        query_vector = embedding_response.data[0].embedding
+
+        # Send vector query to database and get response
+        db_response = db_collection.query.near_vector(
+            near_vector=query_vector,
+            limit=10,
+            return_metadata=MetadataQuery(distance=True)
+        )
+
+    # TODO support hybrid search
+    # elif SEARCH_TYPE == 'hybrid':
+
+
+    else:
+        # No RAG search
+        db_response = None
+
+    # Extract items from database response
+    # and aggregate into a single string
+    db_response_text = ""
+    if db_response is not None:
+        for item in db_response.objects:
+            segment = '\n\n'
+            db_response_text += segment
+            db_response_text += item.properties.get('content', '')
+
+
+    ### 3 - RESPONSE GENERATION ###
+
+    # Generate response to user with OpenAI generative model
+    # https://platform.openai.com/docs/api-reference/chat/create
+    openai.api_key = OPENAI_API_KEY
+    generation_client = OpenAI()
+    generation_response = generation_client.chat.completions.create(
+        model=RESPONDING_GENERATIVE_MODEL,
+        messages=[
+            {
+                "role": "system",
+                "content": f"You are a helpful assistant who uses this context if appropriate: {db_response_text}"
+            },
+            {
+                "role": "user",
+                "content": prompt_request.user_prompt_text
+            }
+        ]
+    )
+
+    # Extract the generated message text from the chat completion response
+    generated_response_text = generation_response.choices[0].message.content
+
+    # Close the database connection
+    weaviate_client.close()
+
+    ### 4 - FINAL TRUST AND SAFETY CHECK ###
+    # TODO add final trust & safety check here
+    # If trust and safety check fails, return an error
+
+
+    ### 5 - BUILD & RETURN RESPONSE OBJECT ###
+    # Return chat response to API layer
+    # to be passed along to frontend
+    return PromptResponse(
+        generated_response=generated_response_text,
+        error_code=0
+    )
+
+
 
 # API endpoint
 @app.post("/process-prompt", response_model=PromptResponse)
-async def process_prompt_endpoint(request: PromptRequest):
+async def process_prompt_endpoint(prompt_request: PromptRequest):
     """
     Process the prompt and return the response
     """
     try:
-        result = process_prompt(request.prompt)
-        return PromptResponse(result=result)
+        prompt_response = process_prompt(prompt_request)
+        return prompt_response
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
-
-
-
-# This is the code to get the collections based on top 3 keywords we are fetching from the RAKE code.
-# You can add this code block wherever you configure your API
-collection = client.collections.get("MeetingDocument")
-response = collection.query.bm25(
-    query=",".join(keywords),
-    limit=5,
-    # rerank=Rerank(
-    #     prop="content",
-    #     query="meeting"
-    # ),
-    # return_metadata=MetadataQuery(score=True)
-)
-
-for o in response.objects:
-    print(o.properties)
-    # print(o.metadata.rerank_score)
+    
\ No newline at end of file
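
The 'hybrid' search type in process_prompt above is left as a TODO. A possible shape for that branch, assuming the weaviate-client v4 hybrid query API and reusing db_collection and prompt_request from the function (query_vector would be computed as in the 'vector' branch; alpha=0.5 is an arbitrary starting point, not a tuned value):

    from weaviate.classes.query import MetadataQuery

    # Hybrid search blends BM25 keyword scoring with vector similarity;
    # alpha=0 is pure keyword search, alpha=1 is pure vector search.
    db_response = db_collection.query.hybrid(
        query=prompt_request.user_prompt_text,  # keyword (BM25) side
        vector=query_vector,                    # precomputed prompt embedding
        alpha=0.5,                              # equal weighting
        limit=5,
        return_metadata=MetadataQuery(score=True),
    )
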
diff --git a/MinuteMate/back/methods/__init__.py b/MinuteMate/back/methods/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/MinuteMate/back/requirements.txt b/MinuteMate/back/requirements.txt
index 06eaddd2..1cfda1c3 100644
--- a/MinuteMate/back/requirements.txt
+++ b/MinuteMate/back/requirements.txt
@@ -6,4 +6,16 @@ fastapi[standard]
 pydantic
 
 # Necessary for web stuff
-uvicorn
\ No newline at end of file
+uvicorn
+
+# vector database
+weaviate-client==4.7.1
+
+# environment variables such as API keys and endpoint URLs
+python-dotenv==1.0.0
+
+# OpenAI embedding and generation calls
+openai==1.54.3
+
+# RAKE keyword extraction (imported by main.py; also pulls in nltk)
+rake-nltk
\ No newline at end of file
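
python-dotenv is added to the requirements above but is not yet wired into main.py, which reads keys straight from the process environment. A minimal sketch of how it could be hooked in ahead of the os.getenv calls, assuming a .env file in the working directory:

    import os
    from dotenv import load_dotenv

    # Reads KEY=value pairs from ./.env into the process environment;
    # by default it does not overwrite variables that are already set.
    load_dotenv()

    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
    WEAVIATE_URL = os.environ['WEAVIATE_URL']
    WEAVIATE_API_KEY = os.environ['WEAVIATE_API_KEY']
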
diff --git a/dev_notebooks/rag_prompt_dev.ipynb b/dev_notebooks/rag_prompt_dev.ipynb
index b026cb32..932bc266 100644
--- a/dev_notebooks/rag_prompt_dev.ipynb
+++ b/dev_notebooks/rag_prompt_dev.ipynb
@@ -151,7 +151,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "metadata": {},
     "outputs": [
      {
@@ -200,7 +200,11 @@
     "ENCODING_FORMAT = \"float\"\n",
     "\n",
     "# Database\n",
-    "COLLECTION_NAME = \"VERBA_Embedding_text_embedding_3_small\""
+    "COLLECTION_NAME = \"VERBA_Embedding_text_embedding_3_small\"\n",
+    "\n",
+    "RESPONDING_GENERATIVE_MODEL = \"gpt-4o\"\n",
+    "\n",
+    "TRUSTSAFETY_GENERATIVE_MODEL = \"gpt-4o\""
    ]
   },
   {
@@ -221,13 +225,13 @@
     "from weaviate.classes.init import Auth, AdditionalConfig, Timeout\n",
     "import weaviate\n",
     "\n",
-    "client_db = weaviate.connect_to_weaviate_cloud(\n",
+    "weaviate_client = weaviate.connect_to_weaviate_cloud(\n",
     "    cluster_url = WEAVIATE_URL,\n",
     "    auth_credentials = Auth.api_key(WEAVIATE_API_KEY),\n",
     "    additional_config=AdditionalConfig(timeout=Timeout(init=30, query=60, insert=120))  # Values in seconds\n",
     ")\n",
     "\n",
-    "print(client_db.is_ready())"
+    "print(weaviate_client.is_ready())"
    ]
   },
   {
@@ -236,7 +240,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# for collection in client_db.collections.list_all():\n",
+    "# for collection in weaviate_client.collections.list_all():\n",
     "#     print(collection)"
    ]
   },
@@ -246,7 +250,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# client_db.close()"
+    "# weaviate_client.close()"
    ]
   },
   {
@@ -286,8 +290,7 @@
     "# Set API Key. Not necessary if you have an \n",
     "# OPENAI_API_KEY variable in your environment\n",
     "openai.api_key = OPENAI_API_KEY \n",
-    "\n",
-    "client_embedding = OpenAI()"
+    "embedding_client = OpenAI()"
    ]
   },
   {
@@ -296,9 +299,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# print(type(client_embedding))\n",
+    "# print(type(embedding_client))\n",
     "\n",
-    "# for item in client_embedding.models.list():\n",
+    "# for item in embedding_client.models.list():\n",
     "#     print(item)"
    ]
   },
@@ -343,24 +346,24 @@
    "source": [
     "from weaviate.classes.query import MetadataQuery\n",
     "\n",
-    "query_text = \"I'd like to know about issues with plumbing in or around 2024\"\n",
+    "user_prompt = \"I'd like to know about issues with plumbing in or around 2024\"\n",
     "\n",
     "# Vectorize the query\n",
-    "response_embedding = client_embedding.embeddings.create(\n",
+    "embedding_response = embedding_client.embeddings.create(\n",
     "    model = EMBEDDING_MODEL,\n",
-    "    input = query_text,\n",
+    "    input = user_prompt,\n",
     "    encoding_format = ENCODING_FORMAT\n",
     ")\n",
     "\n",
-    "# Extract the verctor embeddings list[float] from the embedding response\n",
-    "query_vector = openai_extract_vector(response_embedding) \n",
+    "# Extract the vector embeddings list[float] from the embedding response\n",
+    "vectorized_query = openai_extract_vector(embedding_response) \n",
     "\n",
     "# Look up the appropriate Weaviate database collection - name based on embedding model used\n",
-    "collection = client_db.collections.get('VERBA_Embedding_text_embedding_3_small')\n",
+    "db_collection = weaviate_client.collections.get('VERBA_Embedding_text_embedding_3_small')\n",
     "\n",
     "# Send vector query to database and get response\n",
-    "db_response = collection.query.near_vector(\n",
-    "    near_vector=query_vector,\n",
+    "db_response = db_collection.query.near_vector(\n",
+    "    near_vector=vectorized_query,\n",
     "    limit=10,\n",
     "    return_metadata=MetadataQuery(distance=True)\n",
     ")\n",
@@ -374,11 +377,11 @@
     "    db_response_content += item.properties.get('content')\n",
     "\n",
     "# # Print results\n",
-    "# for item in response_db.objects:\n",
+    "# for item in db_response.objects:\n",
     "#     print(item.properties)\n",
     "#     print(item.metadata.distance)\n",
     "\n",
-    "# print(response_content)"
+    "# print(db_response_content)"
    ]
   },
   {
@@ -392,7 +395,7 @@
    "source": [
     "import openai\n",
     "\n",
     "\n",
     "generation_response = generation_client.chat.completions.create(\n",
-    "    model=\"gpt-4o\",\n",
+    "    model=RESPONDING_GENERATIVE_MODEL,\n",
     "    messages=[\n",
     "        {\n",
     "            \"role\": \"system\", \n",
     "            \"content\": f\"You are a helpful assistant who uses this context if appropriate: {db_response_content}\" \n",
     "        },\n",
     "        {\n",
     "            \"role\": \"user\", \n",
-    "            \"content\": query_text \n",
+    "            \"content\": user_prompt \n",
     "        }\n",
     "    ]\n",
     ")\n",