-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain1.py
196 lines (162 loc) · 6.25 KB
/
main1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from typing import List
from typing_extensions import TypedDict
import re
from pydantic import BaseModel, Field
# langchain related libraries
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
# langgraph related libraries
from langgraph.graph import StateGraph, START, END
class AgentState(TypedDict, total=False):
    """Shared mutable state threaded through the LangGraph nodes.

    total=False because the graph is invoked with only `question` populated
    (see the app.invoke call at the bottom of the file); the remaining keys
    are filled in by individual nodes as the run proceeds.
    """
    question: str           # current (possibly rephrased) user question
    llm_output: str         # latest generated answer / default reply
    documents: list[str]    # retrieved context documents
    cnt_retries: int        # rephrase attempts performed so far (0..3)
    on_topic: str           # 'Yes'/'No', written by question_intent_classifier
    is_answer_ok: str       # 'Yes'/'No', written by grade_answer
#Node — Question Scope Classifier
# Output schema handed to JsonOutputParser so the classifier LLM returns
# a single 'score' field. The docstring and description are runtime strings
# (they flow into the generated format instructions), so they are preserved
# exactly.
class QuestionScopeClass(BaseModel):
    """Scope of the question"""

    score: str = Field(
        description=(
            "Boolean value to check if question is about what, where or"
            " comparison. If yes -> 'Yes', else 'No'"
        )
    )
def question_intent_classifier(state: AgentState):
    """Node: decide whether the question is in scope for this bot.

    Asks the LLM whether the question is about a definition, availability,
    or a comparison. Resets the rephrase-retry counter and stores the
    'Yes'/'No' verdict under state['on_topic'].

    Returns the (mutated) state dict, as LangGraph nodes are expected to.
    """
    question = state["question"]
    state['cnt_retries'] = 0  # a fresh question starts a new retry budget
    parser = JsonOutputParser(pydantic_object=QuestionScopeClass)
    output_format = parser.get_format_instructions()
    # (removed a leftover debug print of the format instructions here)
    system = """You are a question classifier. Check if the question is about one of the following topics:
    1. definition
    2. availability
    3. comparison
    If the question IS about these topics, respond with "Yes", otherwise respond with "No".
    Format output as: `{output_format}`
    """
    intent_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "User question: {question}"),
        ]
    )
    llm = openAI()  # NOTE(review): openAI() is defined elsewhere in the project
    grader_llm = intent_prompt | llm | parser
    result = grader_llm.invoke({"question": question, 'output_format': output_format})
    print(f"to_proceed: {result['score']}")
    state["on_topic"] = result['score']
    return state
# router to enable conditional edges
def on_topic_router(state: AgentState):
print('ontopic router ... ')
on_topic = state["on_topic"]
if on_topic.lower() == "yes":
return "on_topic"
return "off_topic"
def grade_answer(state: AgentState):
answer= state['llm_output']
print('grading....')
pattern =r'do not know|sorry|apolog'
is_answer = 'Yes' if not re.search(pattern, answer.lower()) else 'No'
state['is_answer_ok'] = is_answer
print(f"answer grade: {is_answer}")
return state
def is_answer_router(state: AgentState):
print('grading router ... ')
is_answer = state["is_answer_ok"]
if state['cnt_retries'] >2: # max of 3 retries allowed (0 to 2)
return "hit_max_retries"
if is_answer.lower() == "yes":
return "is_answer"
return "is_not_answer"
def question_rephraser(state: AgentState):
    """Node: rephrase the question to improve retrieval on the next attempt.

    Overwrites state['question'] with the rephrased text and increments
    state['cnt_retries']. Reached via the 'is_not_answer' route when the
    previous generation was graded as a non-answer.
    """
    print('rephrasing ...')
    question = state['question']
    print(f"retrying: {state['cnt_retries']+1}")
    llm = openAI()  # NOTE(review): openAI() is defined elsewhere in the project
    # Typo fix in the prompt: "You hav been" -> "You have been".
    template = """
    You are an expert in rephrasing English questions. \
    You have been tasked to rephrase question from the Retail and Supply Chain domain. \
    While rephrasing, you may do the following:
    1. Extract keywords from the original question
    2. Expand or create abbreviations of the question as needed
    3. Understand the intent of the question
    4. Include the above information to generate a rephrased version of the original question.\
    Do not output anything else apart from the rephrased question.\
    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(
        template=template,
    )
    chain = prompt | llm | StrOutputParser()
    result = chain.invoke({"question": question})
    state['question'] = result
    state['cnt_retries'] += 1
    return state
# retriever = get_retriever(config_vector_store)
# def retrieve_docs(state: AgentState):
# question = state["question"]
# documents = retriever.invoke(input=question)
# state["documents"] = documents
# print(f"cnt of retrieved docs: {len(documents)}")
# return state
def generate_answer(state: AgentState):
    """Node: answer state['question'] grounded in the retrieved documents.

    Builds the context from the page_content of each retrieved document
    (assumes state['documents'] holds retriever Document objects — TODO
    confirm against the retrieval node) and stores the LLM output in
    state['llm_output'].
    """
    print('generating answer ... ')
    query = state['question']
    passages = [doc.page_content for doc in state['documents']]
    llm = openAI()  # NOTE(review): openAI() is defined elsewhere in the project
    template = """
    You are a Customer Support Chatbot aimed to answer the user's queries coming from Retail and Ecommerce industry.\
    Keep the tone conversational and professional.\
    Remember that abbreviations mentioned are related to these domains.\
    Answer the question strictly based on the context provided.\
    Avoid mentioning in the response that a context was referred.\
    Avoid using words like 'certainly" and "it looks like" in the generated response.\
    Do not output anything else apart from the answer.\
    {context}
    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(
        template=template,
    )
    chain = prompt | llm | StrOutputParser()
    answer = chain.invoke({"question": query, "context": passages})
    print(f"result from generation: {answer}")
    state['llm_output'] = answer
    return state
def get_default_reply(state:AgentState):
print('get the default answer ...')
state['llm_output'] = 'I do not have an answer.'
return state
# ---- Build the LangGraph workflow ----
workflow = StateGraph(AgentState)

# Nodes
workflow.add_node('intent_classifier', question_intent_classifier)
# workflow.add_node('retrieve_docs', retrieve_docs)
# NOTE(review): the 'retrieve_docs' node is referenced by the edges below but
# its registration (and the retrieve_docs function / retriever setup) is
# commented out elsewhere in this file — re-enable it before running, or
# workflow.compile() will fail on an unknown node.
workflow.add_node('generate_answer', generate_answer)
workflow.add_node('grade_answer', grade_answer)
workflow.add_node('question_rephraser', question_rephraser)
workflow.add_node('default_reply', get_default_reply)

# Edges, including the conditional routing
# BUG FIX: the entry edge must run START -> 'intent_classifier'; the original
# had the arguments reversed, which leaves the graph without an entry point.
workflow.add_edge(START, 'intent_classifier')
workflow.add_conditional_edges(
    'intent_classifier', on_topic_router,
    {
        'on_topic': 'retrieve_docs',
        'off_topic': 'default_reply'
    }
)
workflow.add_edge('retrieve_docs', 'generate_answer')
workflow.add_edge('generate_answer', 'grade_answer')
workflow.add_conditional_edges(
    'grade_answer', is_answer_router,
    {
        'is_answer': END,
        'is_not_answer': 'question_rephraser',
        'hit_max_retries': 'default_reply'
    }
)
workflow.add_edge('question_rephraser', 'retrieve_docs')
workflow.add_edge('default_reply', END)

# Compile the workflow and run a smoke query
app = workflow.compile()
query = "Capital of India"
response = app.invoke(input={"question": query})
# BUG FIX: the original printed `result`, which is undefined at module scope
# (it only exists as a local inside the node functions); the value returned
# by app.invoke() is `response`.
print(response['llm_output'])