2-rag-langchain.py
#####################
# Imports
#####################
import os

from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
#####################
# Constants
#####################
BOOK_PATH = "books/back-to-the-future-script.txt"
BOOK_NAME = "Back to the Future"
COLLECTION_NAME = BOOK_NAME.lower().replace(" ", "-")
VECTOR_STORE = "./vector-store-langchain"
# langchain-openai reads OPENAI_API_KEY from the environment on its own;
# we look it up here only to fail fast with a clear error if it is missing.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
assert OPENAI_API_KEY, "OPENAI_API_KEY environment variable is not set"
#####################
# ChromaDB setup
#####################
# Build the vector store on the first run; reuse the persisted copy afterwards.
if not os.path.exists(VECTOR_STORE):
    with open(BOOK_PATH) as f:
        text = f.read()
    # Split the script into overlapping ~600-character chunks for embedding.
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=80)
    documents = splitter.create_documents([text])
    vector_store = Chroma.from_documents(
        documents=documents,
        embedding=OpenAIEmbeddings(),
        collection_name=COLLECTION_NAME,
        persist_directory=VECTOR_STORE,
    )
else:
    vector_store = Chroma(
        collection_name=COLLECTION_NAME,
        embedding_function=OpenAIEmbeddings(),
        persist_directory=VECTOR_STORE,
    )
retriever = vector_store.as_retriever()
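# Optional tuning (a sketch, not required): as_retriever defaults to
# similarity search returning 4 chunks; search_kwargs adjusts that, e.g.:
# retriever = vector_store.as_retriever(search_kwargs={"k": 6})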
#####################
# Langchain prompting
#####################
prompt = ChatPromptTemplate.from_template(
    """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer: """
)
def format_docs(docs):
    # Join the retrieved Documents into one context string for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)
llm = ChatOpenAI(model="gpt-4-turbo")
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
response = rag_chain.invoke(
    "What is the name of the main character in Back to the Future?"
)
print(response)
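# A minimal streaming sketch (same chain): LCEL runnables also expose
# .stream(), which yields string chunks once StrOutputParser is in place.
# for chunk in rag_chain.stream(
#     "What is the name of the main character in Back to the Future?"
# ):
#     print(chunk, end="", flush=True)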
# Alternative way to define the rag_chain using itemgetter
# (requires `from operator import itemgetter`; `| format_docs` is added so the
# context is formatted the same way as in the chain above):
# https://python.langchain.com/docs/expression_language/primitives/parallel/#using-itemgetter-as-shorthand
# rag_chain = (
#     {
#         "context": itemgetter("question") | retriever | format_docs,
#         "question": itemgetter("question"),
#     }
#     | prompt
#     | llm
#     | StrOutputParser()
# )
# response = rag_chain.invoke(
#     {"question": "What is the name of the main character in Back to the Future?"}
# )
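# A hedged debugging sketch: retrievers are runnables too, so you can call
# one directly to inspect which chunks back a given answer:
# docs = retriever.invoke(
#     "What is the name of the main character in Back to the Future?"
# )
# print(format_docs(docs))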