
Commit

annoying linting fixed
potthoffjan committed Jul 10, 2024
1 parent 2add1c1 commit 1f8970b
Showing 1 changed file with 80 additions and 97 deletions.
177 changes: 80 additions & 97 deletions src/backend/RAG/LangChain_Implementation/chain.py
@@ -2,40 +2,24 @@
import os
import sys

from dotenv import load_dotenv

from astrapy import DataAPIClient
from astrapy.db import AstraDB
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_anthropic import ChatAnthropic
from langchain_astradb import AstraDBVectorStore
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.vectorstores.chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
# from langchain.embeddings import OpenAIEmbeddings
# from langchain_community.embeddings import OpenAIEmbeddings
# from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_huggingface import HuggingFaceEmbeddings
# from langchain_community.llms.openai import OpenAI
from langchain_openai import OpenAI, OpenAIEmbeddings

"""
Script takes in args:
Script takes in args:
0) list of LLMs with which to retrieve e.g. ['gpt-4', 'gemini', 'mistral']
1) input string
2) chat history in following shape [
Expand All @@ -44,87 +28,85 @@
]
"""


def custom_history(entire_history: list, llm_name: str):
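    """Rebuild one model's view of the conversation as LangChain messages:
    the 'user' turns plus the replies stored under llm_name."""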
    chat_history = []
    for msg in entire_history:
        if 'user' in msg:
            chat_history.extend([HumanMessage(content=msg['user'])])
        if llm_name in msg:
            chat_history.extend([AIMessage(content=msg[llm_name])])
    return chat_history


def main():

    if len(sys.argv) < 4:
        print("""Error: Please provide:
              1) [list of LLM models to use]
              (['gpt-4', 'gemini', 'claude'])
              2) 'input string'
              3) [{chat history}] in the following shape:
              [{'gpt-4': "Hello, how can I help you?"},
               {'user': "What do prisons and plants have in common?"}
               etc.]""")
        sys.exit(1)

    # Arguments
    llm_list = sys.argv[1]
    llm_list = list(llm_list.replace('[', '').replace(']', '').replace("'", '').split(','))
    if not llm_list:
        llm_list = ['gpt-4']
    # print(llm_list)
    input_string = sys.argv[2]
    # print(input_string)
    message_history = sys.argv[3]
    # print(message_history)
    message_history = message_history.split(';;')
    # print(message_history)
    message_history = [json.loads(substring.replace("'", '"')) for substring in message_history]
    # print(message_history)

    load_dotenv()

    # to be put into a separate function in order to invoke LLMs separately
    openai_api_key = os.environ.get('OPENAI_API_KEY')

    # google_api_key = os.environ.get('GOOGLE_API_KEY')
    # anthropic_api_key = os.environ.get('ANTHROPIC_API_KEY')

    # test_llm_list = ['gpt-4']
    # llm_list = test_llm_list
    # test_history = [
    #     {'gpt-4': 'Hello, how can I help you?', 'gemini': 'Hello, how can I help you?'},
    #     {'user': 'What do prisons and plants have in common?'},
    #     {'gpt-4': 'They both have cell walls.', 'gemini': 'They have cell walls.'},
    # ]
    # message_history = test_history

    # test_query = 'Ah, true. Thanks. What else do they have in common?'
    # test_query = "How many corners does a heptagon have?"
    # input_string = test_query
    # test_follow_up = "How does one call a polygon with two more corners?"

    # AstraDB Section
    astra_db_api_endpoint = os.environ.get('ASTRA_DB_API_ENDPOINT')
    astra_db_application_token = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')
    astra_db_namespace = 'test'
    astra_db_collection = 'test_collection_2'

    # LangChain Docs: -------------------------
    vstore = AstraDBVectorStore(
        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
        collection_name=astra_db_collection,
        api_endpoint=astra_db_api_endpoint,
        token=astra_db_application_token,
        namespace=astra_db_namespace,
    )
    # ------------------------------------------

    # For test purposes: -----------------------
    # import bs4
    # from langchain_chroma import Chroma
    # from langchain_community.document_loaders import WebBaseLoader

    # loader = WebBaseLoader(
    #     web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    #     bs_kwargs=dict(
@@ -140,70 +122,71 @@ def main():
    # vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
    # retriever = vectorstore.as_retriever()
    # # test end ----------------------------------
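
    # k=3: return the three most similar chunks per query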
retriever = vstore.as_retriever(search_kwargs={"k": 3})

retriever = vstore.as_retriever(search_kwargs={'k': 3})

    contextualize_q_system_prompt = """Given a chat history and the latest user question \
    which might reference context in the chat history, formulate a standalone question \
    which can be understood without the chat history. Do NOT answer the question, \
    just reformulate it if needed and otherwise return it as is."""

    qa_system_prompt = """You are an assistant for question-answering tasks. \
    Use the following pieces of retrieved context to answer the question. \
    If you don't know the answer, just say that you don't know. \
    Use three sentences maximum and keep the answer concise.\
    {context}"""

    answers = {}
    for _llm in llm_list:
        # print(_llm)
        chat_history = custom_history(message_history, _llm)
        if _llm == 'gpt-4':
            llm = OpenAI(temperature=0.2)
        elif _llm == 'gemini':
            llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro-latest')
        elif _llm == 'claude':
            llm = ChatAnthropic(model_name='claude-3-opus-20240229')
        else:
            # unknown model name: skip instead of leaving llm unbound
            continue

        print(chat_history)
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextualize_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )
        history_aware_retriever = create_history_aware_retriever(
            llm, retriever, contextualize_q_prompt
        )

        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )
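
        # The combined chain returns a dict; the generated text sits under 'answer'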
        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
        ### Answer question ###
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
        msg = rag_chain.invoke({'input': input_string, 'chat_history': chat_history})
        answers[_llm] = msg['answer']
        print(msg['answer'])
    # print(answers)

    # chat_history.extend([HumanMessage(content=input_string), AIMessage(content=msg_1["answer"])])
    # print(msg_1['input'])
    # print(msg_1['answer'])
    # print(chat_history)
    # msg_2 = rag_chain.invoke({"input": test_follow_up, "chat_history": chat_history})
    # chat_history.extend([HumanMessage(content=test_follow_up),
    #                      AIMessage(content=msg_2["answer"])])
    # print(msg_2['input'])
    # print(msg_2['answer'])
    # print(chat_history)
    return answers

if __name__ == "__main__":

if __name__ == '__main__':
main()
