From 2add1c1732e852a025da9ba87e64db509b5e9c32 Mon Sep 17 00:00:00 2001
From: potthoffjan
Date: Wed, 10 Jul 2024 09:36:26 +0200
Subject: [PATCH 1/2] History-aware retriever, multiple LLM answers, and basic
 in/out

Signed-off-by: potthoffjan

---
 .../RAG/LangChain_Implementation/chain.py | 209 ++++++++++++++++++
 1 file changed, 209 insertions(+)
 create mode 100644 src/backend/RAG/LangChain_Implementation/chain.py

diff --git a/src/backend/RAG/LangChain_Implementation/chain.py b/src/backend/RAG/LangChain_Implementation/chain.py
new file mode 100644
index 0000000..b84c75a
--- /dev/null
+++ b/src/backend/RAG/LangChain_Implementation/chain.py
@@ -0,0 +1,209 @@
+import json
+import os
+import sys
+
+
+from dotenv import load_dotenv
+
+from astrapy import DataAPIClient
+from astrapy.db import AstraDB
+from langchain_astradb import AstraDBVectorStore
+
+#from langchain.embeddings import OpenAIEmbeddings
+#from langchain_community.embeddings import OpenAIEmbeddings
+from langchain_openai import OpenAIEmbeddings
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+#from langchain_community.embeddings import HuggingFaceEmbeddings
+#from langchain_huggingface import HuggingFaceEmbeddings
+#from langchain_community.llms.openai import OpenAI
+from langchain_openai import ChatOpenAI
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_community.vectorstores.chroma import Chroma
+
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.chains import create_history_aware_retriever
+from langchain_core.prompts import MessagesPlaceholder
+from langchain.chains import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.messages import HumanMessage, AIMessage
+from langchain_anthropic import AnthropicLLM
+from langchain_anthropic import ChatAnthropic
+
+"""
+Script takes in args:
+1) list of LLMs with which to retrieve, e.g. ['gpt-4', 'gemini', 'claude']
+2) input string
+3) chat history in the following shape [
+    {'gpt-4': "Hello, how can I help you?"},
+    {'user': "What do prisons and plants have in common?"}
+]
+"""
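+# Example invocation (hypothetical values). Note that main() naively swaps
+# every single quote for a double quote before json.loads, so history entries
+# must use single quotes and must not contain apostrophes:
+#
+#   python chain.py "['gpt-4', 'gemini']" \
+#       "What else do they have in common?" \
+#       "{'gpt-4': 'Hello, how can I help you?'};;{'user': 'What do prisons and plants have in common?'}"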
+
+def custom_history(entire_history:list, llm_name:str):
+    chat_history = []
+    for msg in entire_history:
+        if 'user' in msg:
+            chat_history.extend([HumanMessage(content=msg['user'])])
+        if llm_name in msg:
+            chat_history.extend([AIMessage(content=msg[llm_name])])
+    return chat_history
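+
+# For illustration (hypothetical data): with the docstring's example history and
+# llm_name='gpt-4', custom_history() returns
+#   [AIMessage(content="Hello, how can I help you?"),
+#    HumanMessage(content="What do prisons and plants have in common?")]
+# i.e. the shared user turns plus only that one model's own earlier answers.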
+
+def main():
+
+    if len(sys.argv) < 4:
+        print("""Error: Please provide:
+            1) [list of LLM models to use]
+            (['gpt-4', 'gemini', 'claude'])
+            2) 'input string'
+            3) [{chat history}] in the following shape:
+            [{'gpt-4': "Hello, how can I help you?"},
+            {'user': "What do prisons and plants have in common?"}
+            etc.]""")
+        sys.exit(1)
+
+    # Arguments
+    llm_list = sys.argv[1]
+    llm_list = [m.strip() for m in llm_list.replace('[', '').replace(']', '').replace("'", '').split(',') if m.strip()]
+    if not llm_list:
+        llm_list = ['gpt-4']
+    #print(llm_list)
+    input_string = sys.argv[2]
+    #print(input_string)
+    message_history = sys.argv[3]
+    #print(message_history)
+    message_history = message_history.split(';;')
+    #print(message_history)
+    message_history = [json.loads(substring.replace("'", '"')) for substring in message_history]
+    #print(message_history)
+
+    load_dotenv()
+
+    # to be put into a separate function in order to invoke LLMs separately
+    openai_api_key = os.environ.get('OPENAI_API_KEY')
+    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
+    ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
+
+    test_llm_list = ['gpt-4']
+    #llm_list = test_llm_list
+    test_history = [
+        {'gpt-4': "Hello, how can I help you?",
+         'gemini': "Hello, how can I help you?"},
+        {'user': "What do prisons and plants have in common?"},
+        {'gpt-4': "They both have cell walls.",
+         'gemini': "They have cell walls."},
+    ]
+    # message_history = test_history
+
+    test_query = "Ah, true. Thanks. What else do they have in common?"
+    # test_query = "How many corners does a heptagon have?"
+    # input_string = test_query
+    # test_follow_up = "What does one call a polygon with two more corners?"
+
+    # AstraDB Section
+    ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT')
+    ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
+    ASTRA_DB_NAMESPACE = 'test'
+    ASTRA_DB_COLLECTION = 'test_collection_2'
+
+    # LangChain Docs: -------------------------
+    vstore = AstraDBVectorStore(
+        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
+        collection_name=ASTRA_DB_COLLECTION,
+        api_endpoint=ASTRA_DB_API_ENDPOINT,
+        token=ASTRA_DB_APPLICATION_TOKEN,
+        namespace=ASTRA_DB_NAMESPACE,
+    )
+    # ------------------------------------------
+
+    # For test purposes: -----------------------
+    # import bs4
+    # from langchain_chroma import Chroma
+    # from langchain_community.document_loaders import WebBaseLoader
+
+    # loader = WebBaseLoader(
+    #     web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
+    #     bs_kwargs=dict(
+    #         parse_only=bs4.SoupStrainer(
+    #             class_=("post-content", "post-title", "post-header")
+    #         )
+    #     ),
+    # )
+    # docs = loader.load()
+
+    # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+    # splits = text_splitter.split_documents(docs)
+    # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
+    # retriever = vectorstore.as_retriever()
+    # # test end ----------------------------------
+
+    retriever = vstore.as_retriever(search_kwargs={"k": 3})
+
+    contextualize_q_system_prompt = """Given a chat history and the latest user question \
+    which might reference context in the chat history, formulate a standalone question \
+    which can be understood without the chat history. Do NOT answer the question, \
+    just reformulate it if needed and otherwise return it as is."""
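+    # Illustrative effect: with the test history above, a follow-up such as
+    # "What else do they have in common?" should come back as roughly
+    # "What else do prisons and plants have in common?" before retrieval.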
+
+    qa_system_prompt = """You are an assistant for question-answering tasks. \
+    Use the following pieces of retrieved context to answer the question. \
+    If you don't know the answer, just say that you don't know. \
+    Use three sentences maximum and keep the answer concise.\
+
+    {context}"""
+
+    answers = {}
+    for _llm in llm_list:
+        #print(_llm)
+        chat_history = custom_history(message_history, _llm)
+        if _llm == 'gpt-4':
+            llm = ChatOpenAI(model="gpt-4", temperature=0.2)
+        elif _llm == 'gemini':
+            llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest")
+        elif _llm == 'claude':
+            llm = ChatAnthropic(model_name="claude-3-opus-20240229")
+        else:
+            # Skip unknown model names so llm is never referenced unbound.
+            continue
+
+        print(chat_history)
+        contextualize_q_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", contextualize_q_system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        history_aware_retriever = create_history_aware_retriever(
+            llm, retriever, contextualize_q_prompt
+        )
+
+        qa_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", qa_system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+        ### Answer question ###
+        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+        msg = rag_chain.invoke({"input": input_string, "chat_history": chat_history})
+        answers[_llm] = msg['answer']
+        print(msg['answer'])
+    #print(answers)
+
+    # chat_history.extend([HumanMessage(content=input_string), AIMessage(content=msg_1["answer"])])
+    # print(msg_1['input'])
+    # print(msg_1['answer'])
+    # print(chat_history)
+    # msg_2 = rag_chain.invoke({"input": test_follow_up, "chat_history": chat_history})
+    # chat_history.extend([HumanMessage(content=test_follow_up), AIMessage(content=msg_2["answer"])])
+    # print(msg_2['input'])
+    # print(msg_2['answer'])
+    # print(chat_history)
+    return answers
+
+if __name__ == "__main__":
+    main()
+    
\ No newline at end of file

From 1f8970bae5e4088b1328892523f1ef2d4ae6947c Mon Sep 17 00:00:00 2001
From: potthoffjan
Date: Wed, 10 Jul 2024 10:20:25 +0200
Subject: [PATCH 2/2] Fix annoying linting issues

---
 .../RAG/LangChain_Implementation/chain.py | 177 ++++++++----------
 1 file changed, 80 insertions(+), 97 deletions(-)

diff --git a/src/backend/RAG/LangChain_Implementation/chain.py b/src/backend/RAG/LangChain_Implementation/chain.py
index b84c75a..005fc36 100644
--- a/src/backend/RAG/LangChain_Implementation/chain.py
+++ b/src/backend/RAG/LangChain_Implementation/chain.py
@@ -2,40 +2,24 @@
 import os
 import sys
 
-
 from dotenv import load_dotenv
-
-from astrapy import DataAPIClient
-from astrapy.db import AstraDB
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_anthropic import ChatAnthropic
 from langchain_astradb import AstraDBVectorStore
-
-#from langchain.embeddings import OpenAIEmbeddings
-#from langchain_community.embeddings import OpenAIEmbeddings
-from langchain_openai import OpenAIEmbeddings
-from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-#from langchain_community.embeddings import HuggingFaceEmbeddings
-#from langchain_huggingface import HuggingFaceEmbeddings
-#from langchain_community.llms.openai import OpenAI
-from langchain_openai import ChatOpenAI
+from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_community.vectorstores.chroma import Chroma
-
-from langchain_core.prompts import ChatPromptTemplate
-from langchain.chains import create_history_aware_retriever
-from langchain_core.prompts import MessagesPlaceholder
-from langchain.chains import create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_community.chat_message_histories import ChatMessageHistory
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_core.messages import HumanMessage, AIMessage
-from langchain_anthropic import AnthropicLLM
-from langchain_anthropic import ChatAnthropic
+# from langchain.embeddings import OpenAIEmbeddings
+# from langchain_community.embeddings import OpenAIEmbeddings
+# from langchain_community.embeddings import HuggingFaceEmbeddings
+# from langchain_huggingface import HuggingFaceEmbeddings
+# from langchain_community.llms.openai import OpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 
 """
-Script takes in args: 
+Script takes in args:
 1) list of LLMs with which to retrieve, e.g. ['gpt-4', 'gemini', 'claude']
 2) input string
 3) chat history in the following shape [
     {'gpt-4': "Hello, how can I help you?"},
     {'user': "What do prisons and plants have in common?"}
 ]
 """
 
-def custom_history(entire_history:list, llm_name:str):
+
+def custom_history(entire_history: list, llm_name: str):
     chat_history = []
     for msg in entire_history:
-        if 'user' in msg: 
+        if 'user' in msg:
             chat_history.extend([HumanMessage(content=msg['user'])])
         if llm_name in msg:
             chat_history.extend([AIMessage(content=msg[llm_name])])
     return chat_history
-    
+
 
 def main():
-
     if len(sys.argv) < 4:
-        print("""Error: Please provide: 
+        print("""Error: Please provide:
             1) [list of LLM models to use]
             (['gpt-4', 'gemini', 'claude'])
             2) 'input string'
-            3) [{chat history}] in the following shape: 
+            3) [{chat history}] in the following shape:
             [{'gpt-4': "Hello, how can I help you?"},
             {'user': "What do prisons and plants have in common?"}
             etc.]""")
         sys.exit(1)
-
-    # Arguments 
+
+    # Arguments
     llm_list = sys.argv[1]
     llm_list = [m.strip() for m in llm_list.replace('[', '').replace(']', '').replace("'", '').split(',') if m.strip()]
     if not llm_list:
         llm_list = ['gpt-4']
-    #print(llm_list)
+    # print(llm_list)
     input_string = sys.argv[2]
-    #print(input_string)
+    # print(input_string)
     message_history = sys.argv[3]
-    #print(message_history)
+    # print(message_history)
     message_history = message_history.split(';;')
-    #print(message_history)
+    # print(message_history)
     message_history = [json.loads(substring.replace("'", '"')) for substring in message_history]
-    #print(message_history)
-
-    load_dotenv()
-
-    # to be put into a separate function in order to invoke LLMs separately
+    # print(message_history)
+
+    load_dotenv()
+
+    # to be put into a separate function in order to invoke LLMs separately
     openai_api_key = os.environ.get('OPENAI_API_KEY')
-    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
-    ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
-
-    test_llm_list = ['gpt-4']
-    #llm_list = test_llm_list
-    test_history = [
-        {'gpt-4': "Hello, how can I help you?",
-         'gemini': "Hello, how can I help you?"},
-        {'user': "What do prisons and plants have in common?"},
-        {'gpt-4': "They both have cell walls.",
-         'gemini': "They have cell walls."},
-    ]
+    # google_api_key = os.environ.get('GOOGLE_API_KEY')
+    # anthropic_api_key = os.environ.get('ANTHROPIC_API_KEY')
+
+    # test_llm_list = ['gpt-4']
+    # llm_list = test_llm_list
+    # test_history = [
+    #     {'gpt-4': 'Hello, how can I help you?', 'gemini': 'Hello, how can I help you?'},
+    #     {'user': 'What do prisons and plants have in common?'},
+    #     {'gpt-4': 'They both have cell walls.', 'gemini': 'They have cell walls.'},
+    # ]
     # message_history = test_history
-
-    test_query = "Ah, true. Thanks. What else do they have in common?"
+
+    # test_query = 'Ah, true. Thanks. What else do they have in common?'
     # test_query = "How many corners does a heptagon have?"
     # input_string = test_query
     # test_follow_up = "What does one call a polygon with two more corners?"
-
-    # AstraDB Section
-    ASTRA_DB_API_ENDPOINT = os.environ.get('ASTRA_DB_API_ENDPOINT')
-    ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
-    ASTRA_DB_NAMESPACE = 'test'
-    ASTRA_DB_COLLECTION = 'test_collection_2'
-
+
+    # AstraDB Section
+    astra_db_api_endpoint = os.environ.get('ASTRA_DB_API_ENDPOINT')
+    astra_db_application_token = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
+    astra_db_namespace = 'test'
+    astra_db_collection = 'test_collection_2'
+
     # LangChain Docs: -------------------------
     vstore = AstraDBVectorStore(
-        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
-        collection_name=ASTRA_DB_COLLECTION,
-        api_endpoint=ASTRA_DB_API_ENDPOINT,
-        token=ASTRA_DB_APPLICATION_TOKEN,
-        namespace=ASTRA_DB_NAMESPACE,
+        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
+        collection_name=astra_db_collection,
+        api_endpoint=astra_db_api_endpoint,
+        token=astra_db_application_token,
+        namespace=astra_db_namespace,
     )
     # ------------------------------------------
-
+
     # For test purposes: -----------------------
     # import bs4
     # from langchain_chroma import Chroma
     # from langchain_community.document_loaders import WebBaseLoader
-
+
     # loader = WebBaseLoader(
     #     web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
     #     bs_kwargs=dict(
     #         parse_only=bs4.SoupStrainer(
     #             class_=("post-content", "post-title", "post-header")
     #         )
     #     ),
     # )
     # docs = loader.load()
 
     # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
     # splits = text_splitter.split_documents(docs)
     # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
     # retriever = vectorstore.as_retriever()
     # # test end ----------------------------------
@@ -140,70 +122,71 @@ def main():
-
-    retriever = vstore.as_retriever(search_kwargs={"k": 3})
-
+
+    retriever = vstore.as_retriever(search_kwargs={'k': 3})
+
     contextualize_q_system_prompt = """Given a chat history and the latest user question \
     which might reference context in the chat history, formulate a standalone question \
     which can be understood without the chat history. Do NOT answer the question, \
     just reformulate it if needed and otherwise return it as is."""
-
+
     qa_system_prompt = """You are an assistant for question-answering tasks. \
     Use the following pieces of retrieved context to answer the question. \
     If you don't know the answer, just say that you don't know. \
     Use three sentences maximum and keep the answer concise.\
 
     {context}"""
-
+
     answers = {}
-    for _llm in llm_list: 
+    for _llm in llm_list:
         # print(_llm)
         chat_history = custom_history(message_history, _llm)
         if _llm == 'gpt-4':
-            llm = ChatOpenAI(model="gpt-4", temperature=0.2)
+            llm = ChatOpenAI(model='gpt-4', temperature=0.2)
         elif _llm == 'gemini':
-            llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest")
+            llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro-latest')
         elif _llm == 'claude':
-            llm = ChatAnthropic(model_name="claude-3-opus-20240229")
+            llm = ChatAnthropic(model_name='claude-3-opus-20240229')
         else:
             # Skip unknown model names so llm is never referenced unbound.
             continue
-
+
         print(chat_history)
         contextualize_q_prompt = ChatPromptTemplate.from_messages(
             [
-                ("system", contextualize_q_system_prompt),
-                MessagesPlaceholder("chat_history"),
-                ("human", "{input}"),
+                ('system', contextualize_q_system_prompt),
+                MessagesPlaceholder('chat_history'),
+                ('human', '{input}'),
             ]
         )
        history_aware_retriever = create_history_aware_retriever(
             llm, retriever, contextualize_q_prompt
         )
-
+
         qa_prompt = ChatPromptTemplate.from_messages(
             [
-                ("system", qa_system_prompt),
-                MessagesPlaceholder("chat_history"),
-                ("human", "{input}"),
+                ('system', qa_system_prompt),
+                MessagesPlaceholder('chat_history'),
+                ('human', '{input}'),
             ]
         )
         question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
         ### Answer question ###
         rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-        msg = rag_chain.invoke({"input": input_string, "chat_history": chat_history})
+        msg = rag_chain.invoke({'input': input_string, 'chat_history': chat_history})
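+        # Illustrative result shape: the retrieval chain returns a dict that
+        # carries the inputs plus 'context' (the retrieved documents) and
+        # 'answer', e.g.
+        #   {'input': '...', 'chat_history': [...],
+        #    'context': [Document(...)], 'answer': '...'}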
         answers[_llm] = msg['answer']
         print(msg['answer'])
-    #print(answers)
-
+    # print(answers)
+
     # chat_history.extend([HumanMessage(content=input_string), AIMessage(content=msg_1["answer"])])
     # print(msg_1['input'])
     # print(msg_1['answer'])
     # print(chat_history)
-    # msg_2 = rag_chain.invoke({"input": test_follow_up, "chat_history": chat_history})
-    # chat_history.extend([HumanMessage(content=test_follow_up), AIMessage(content=msg_2["answer"])])
+    # msg_2 = rag_chain.invoke({"input": test_follow_up, "chat_history": chat_history})
+    # chat_history.extend([HumanMessage(content=test_follow_up),
+    #                      AIMessage(content=msg_2["answer"])])
     # print(msg_2['input'])
     # print(msg_2['answer'])
     # print(chat_history)
     return answers
 
-if __name__ == "__main__":
+
+if __name__ == '__main__':
     main()
-    
\ No newline at end of file