"""MinuteMate prompt & response backend (FastAPI).

Accepts a user prompt, retrieves relevant meeting-document chunks from
Weaviate (keyword/BM25 or vector search), and generates an answer with
OpenAI chat completions.
"""

import os
import logging
import ssl
from typing import List, Optional, Tuple

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

import weaviate
from weaviate.classes.init import Auth

from openai import OpenAI

from rake_nltk import Rake
from dotenv import load_dotenv

import nltk

# Work around environments with broken SSL certificate chains so the NLTK
# downloads below can succeed.
# NOTE(review): this disables certificate verification process-wide --
# acceptable on a dev box, not for production.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# RAKE needs these NLTK corpora; the download is best-effort so a transient
# network failure does not prevent the API from starting.
try:
    nltk.download('punkt')
    nltk.download('punkt_tab')
    nltk.download('stopwords')
except Exception as e:
    logger.warning("Error downloading NLTK resources: %s", e)

load_dotenv()

# Initialize the FastAPI app
app = FastAPI(
    title="MinuteMate Prompt & Response API",
    description="An AI-powered API for processing meeting-related prompts",
    version="1.0.0"
)

# Add CORS middleware.
# NOTE(review): wide open ("*") -- tighten allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],     # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],     # Allows all methods
    allow_headers=["*"],     # Allows all headers
)


# Define request and response models
class PromptRequest(BaseModel):
    """Incoming prompt from the frontend."""
    user_prompt_text: str = Field(..., min_length=1, max_length=1000)


class ContextSegment(BaseModel):
    """One retrieved chunk of meeting-document context."""
    chunk_id: int
    content: str
    score: Optional[float] = None


class PromptResponse(BaseModel):
    """API response: generated answer plus retrieval metadata."""
    generated_response: str
    context_segments: List[ContextSegment] = []
    keywords: List[str] = []
    error_code: int = 0


class WeaviateConfig:
    """Configuration for Weaviate connection and querying."""

    SEARCH_TYPES = {
        'keyword': 'bm25',
        'vector': 'near_vector',
        'hybrid': 'hybrid'
    }

    @classmethod
    def get_weaviate_client(cls, url: str, api_key: str):
        """Establish a Weaviate cloud connection.

        Raises:
            Exception: re-raised after logging if the connection fails,
            so the service fails fast at startup.
        """
        try:
            return weaviate.connect_to_weaviate_cloud(
                cluster_url=url,
                auth_credentials=Auth.api_key(api_key),
                additional_config=weaviate.classes.init.AdditionalConfig(
                    timeout=weaviate.classes.init.Timeout(init=10, query=30)
                )
            )
        except Exception as e:
            logger.error(f"Weaviate connection error: {e}")
            raise


class PromptProcessor:
    """Main class for processing user prompts."""

    def __init__(self):
        # Load environment variables (raises if any are missing)
        self.load_env_vars()

        # Initialize clients
        self.weaviate_client = WeaviateConfig.get_weaviate_client(
            self.WEAVIATE_URL,
            self.WEAVIATE_API_KEY
        )
        self.openai_client = OpenAI(api_key=self.OPENAI_API_KEY)

    def load_env_vars(self):
        """Load and validate required environment variables.

        Raises:
            ValueError: if any required variable is missing or empty.
        """
        required_vars = [
            'OPENAI_API_KEY',
            'WEAVIATE_URL',
            'WEAVIATE_API_KEY'
        ]

        for var in required_vars:
            value = os.getenv(var)
            if not value:
                raise ValueError(f"Missing environment variable: {var}")
            # SECURITY FIX: the previous version printed the secret values
            # to stdout; log only the variable name.
            logger.info("Loaded environment variable %s", var)
            setattr(self, var, value)

    def extract_keywords(self, text: str) -> List[str]:
        """Extract up to three ranked key phrases from *text* using RAKE.

        Returns an empty list on any extraction error (best-effort).
        """
        try:
            rake = Rake()
            rake.extract_keywords_from_text(text)
            return rake.get_ranked_phrases()[:3]
        except Exception as e:
            logger.error(f"Keyword extraction error: {e}")
            return []

    def search_weaviate(
        self,
        query: str,
        search_type: str = 'keyword'
    ) -> Tuple[List[ContextSegment], List[str]]:
        """Search the MeetingDocument collection.

        Args:
            query: the raw user prompt text.
            search_type: 'keyword' (BM25 over RAKE keywords) or 'vector'
                (OpenAI embedding + near-vector search).

        Returns:
            (context_segments, keywords) -- keywords is empty for vector
            search.  On any error both lists are empty, so callers can
            always unpack the tuple safely.
        """
        # BUG FIX: define keywords up front -- previously the 'vector'
        # branch never assigned it, producing a NameError at the return.
        keywords: List[str] = []
        try:
            collection = self.weaviate_client.collections.get('MeetingDocument')

            if search_type == 'keyword':
                keywords = self.extract_keywords(query)
                results = collection.query.bm25(
                    query=",".join(keywords),
                    limit=5
                )
                logger.info("BM25 search keywords: %s", keywords)
            elif search_type == 'vector':
                embedding = self.openai_client.embeddings.create(
                    model='text-embedding-3-small',
                    input=query
                ).data[0].embedding

                results = collection.query.near_vector(
                    near_vector=embedding,
                    limit=5
                )
            else:
                raise ValueError(f"Unsupported search type: {search_type}")

            context_segments = [
                ContextSegment(
                    chunk_id=int(item.properties.get('chunk_id', 0)),
                    content=item.properties.get('content', ''),
                    # 'distance' is only populated for vector queries;
                    # BM25 results leave score as None.
                    score=getattr(item.metadata, 'distance', None)
                ) for item in results.objects
            ]
            return context_segments, keywords
        except Exception as e:
            logger.error(f"Weaviate search error: {e}")
            # BUG FIX: previously returned a bare [] here, which broke
            # tuple unpacking in process_prompt on any search failure.
            return [], []

    def generate_response(self, prompt: str, context_segments: List[ContextSegment]) -> str:
        """Generate an answer with OpenAI, grounding it in retrieved context.

        Returns a canned apology string on any OpenAI error so the API
        never 500s from this layer.
        """
        context_text = "\n".join(f"\n{seg.content}" for seg in context_segments)

        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "system",
                        "content": f"Use this context if relevant: {context_text}"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ]
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"OpenAI generation error: {e}")
            return "I'm sorry, but I couldn't generate a response."

    def process_prompt(self, prompt_request: PromptRequest) -> PromptResponse:
        """Main method to process a user prompt end to end.

        Never raises: any unexpected error is converted into a
        PromptResponse with error_code=500.
        """
        try:
            # Search for relevant context (always returns a 2-tuple)
            context_segments, keywords = self.search_weaviate(
                prompt_request.user_prompt_text
            )

            # Generate response
            generated_response = self.generate_response(
                prompt_request.user_prompt_text,
                context_segments
            )

            return PromptResponse(
                generated_response=generated_response,
                context_segments=context_segments,
                keywords=keywords,
                error_code=0
            )
        except Exception as e:
            logger.error(f"Prompt processing error: {e}")
            return PromptResponse(
                generated_response="An error occurred while processing your request.",
                error_code=500
            )


# Initialize processor at import time (fails fast if env vars are missing)
processor = PromptProcessor()


# API Endpoint
@app.post("/process-prompt", response_model=PromptResponse)
async def process_prompt_endpoint(prompt_request: PromptRequest):
    """Process user prompt and return response."""
    return processor.process_prompt(prompt_request)


# Cleanup on shutdown
@app.on_event("shutdown")
async def shutdown_event():
    """Close Weaviate connection on app shutdown."""
    processor.weaviate_client.close()
- with st.chat_message("assistant"): - message_placeholder = st.empty() - full_response = "" - for chunk in response.split(): - full_response += chunk + " " - time.sleep(0.05) - message_placeholder.markdown(full_response + "▌") - message_placeholder.markdown(full_response) - - st.session_state.messages.append({"role": "assistant", "content": full_response}) + try: + # Make API call to backend + response = requests.post( + "http://localhost:8000/process-prompt", # Adjust URL as needed + json={"user_prompt_text": prompt} + ) + + # Check if request was successful + if response.status_code == 200: + # Extract the generated response + generated_response = response.json().get('generated_response', 'No response generated') + respo = response.json() + + # Display the response + with st.chat_message("assistant"): + st.markdown(respo) + + # Add to message history + st.session_state.messages.append({ + "role": "assistant", + "content": generated_response + }) + else: + st.error(f"API Error: {response.text}") + + except requests.RequestException as e: + st.error(f"Connection error: {e}")