"""Model, prompt, and RAG-chain construction for the question-answering app."""
from config import settings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
def get_model():
    """
    Build the ChatOpenAI client used by the RAG chain.

    Endpoint, credentials, and model identifier are read from the
    application settings; temperature is pinned to 0 for deterministic
    answers.

    Returns:
        ChatOpenAI: A configured chat model instance.
    """
    client_config = {
        "base_url": settings.BASE_URL,
        "api_key": settings.API_KEY,
        "model_name": settings.MODEL,
        "temperature": 0,
    }
    return ChatOpenAI(**client_config)
def get_prompt():
    """
    Build the chat prompt template for question answering.

    The template is a fixed system instruction, followed by the running
    chat history, followed by a user turn that embeds the retrieved
    context above the question.

    Returns:
        ChatPromptTemplate: The assembled prompt template.
    """
    system_turn = (
        "system",
        """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know or that you cannot answer. Use three sentences maximum and keep the answer concise.""",
    )
    user_turn = ("user", """Context:
---
{context}
---
{question}""")
    messages = [
        system_turn,
        MessagesPlaceholder(variable_name="chat_history"),
        user_turn,
    ]
    return ChatPromptTemplate.from_messages(messages)
def get_rag_chain(retriever, memory_buffer):
"""
Constructs and returns a RAG (Retrieval-Augmented Generation) chain.
Args:
retriever: The retriever object for fetching relevant documents.
memory_buffer: The memory buffer for storing chat history.
Returns:
RunnablePassthrough: The constructed RAG chain.
"""
model = get_model()
prompt = get_prompt()
memory = RunnablePassthrough.assign(chat_history=RunnableLambda(memory_buffer.load_memory_variables) | itemgetter("chat_history"))
return (
{
"context": retriever,
"question": RunnablePassthrough(),
}
| memory
| prompt
| model
| StrOutputParser()
)