webapp.py
"""
A simple web application to implement a chatbot. This app uses Streamlit
for the UI and the Python requests package to talk to an API endpoint that
implements text generation and Retrieval Augmented Generation (RAG) using LLMs
and Amazon OpenSearch as the vector database.
"""
import boto3
import streamlit as st
import requests as req
from typing import List, Tuple, Dict, Optional
from qna import ask

# utility functions
def get_cfn_outputs(stackname: str) -> Dict:
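    """Return the outputs of the given CloudFormation stack as a
    dictionary of OutputKey -> OutputValue."""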
    cfn = boto3.client('cloudformation')
    outputs = {}
    for output in cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']:
        outputs[output['OutputKey']] = output['OutputValue']
    return outputs

# global constants
STREAMLIT_SESSION_VARS: List[Tuple] = [("generated", []), ("past", []), ("input", ""), ("stored_session", [])]
HTTP_OK: int = 200

# two options for the chatbot: 1) get the answer directly from the LLM,
# 2) use RAG (find documents similar to the user query and then provide
# those as context to the LLM)
MODE_RAG: str = 'RAG'
MODE_TEXT2TEXT: str = 'Text Generation'
MODE_VALUES: List[str] = [MODE_RAG, MODE_TEXT2TEXT]

# currently we use Claude V2 for text generation and sbert for
# embeddings, but in the future we could support more models
TEXT2TEXT_MODEL_LIST: List[str] = ["Claude V2"]
EMBEDDINGS_MODEL_LIST: List[str] = ["sbert"]
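# NOTE: these single-entry lists only populate the sidebar selectboxes and the
# subheader text below; only the mode selection changes app behavior, the
# model selections are display-only in this version
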
# if running this app on a compute environment that has
# IAM cloudformation:DescribeStacks access, read the
# stack outputs to get the name of the LLM endpoint
CFN_ACCESS = False
if CFN_ACCESS is True:
    CFN_STACK_NAME: str = "llm-apps-blog-rag"
    outputs = get_cfn_outputs(CFN_STACK_NAME)
else:
    # create an outputs dictionary with the keys of interest;
    # the key values would need to be edited manually before
    # running this app
    outputs: Dict = {}
    # REPLACE __API_GW_ENDPOINT__ WITH ACTUAL API GW ENDPOINT URL
    outputs["LLMAppAPIEndpoint"] = "__API_GW_ENDPOINT__"

# API endpoint
# this is retrieved from the CloudFormation template that was
# used to create this solution
api: str = outputs.get("LLMAppAPIEndpoint")
# api_rag_ep: str = f"{api}/api/v1/llm/rag"
# api_text2text_ep: str = f"{api}/api/v1/llm/text2text"
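# the endpoint URLs below are hardcoded overrides of the values the
# commented-out lines above would derive from the stack output; the RAG
# endpoint is left empty because this version answers text2text queries
# locally via qna.ask (see below)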
api_rag_ep: str = ""
api_text2text_ep: str = "https://qb2iyyqvmh.execute-api.us-east-1.amazonaws.com/dev/ask"
print(f"api_rag_ep={api_rag_ep}\napi_text2text_ep={api_text2text_ep}")

####################
# Streamlit code
####################

# Page title
st.set_page_config(page_title='Virtual assistant for knowledge base 👩‍💻', layout='wide')

# keep track of conversations by using streamlit session state
_ = [st.session_state.setdefault(k, v) for k, v in STREAMLIT_SESSION_VARS]
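# (Streamlit reruns this script top to bottom on every interaction, so the
# conversation history has to live in st.session_state to survive reruns)
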

# Define function to get user input
def get_user_input() -> str:
    """
    Returns the text entered by the user
    """
    print(st.session_state)
    input_text = st.text_input("You: ",
                               st.session_state["input"],
                               key="input",
                               placeholder="Ask me a question and I will consult the knowledge base to answer...",
                               label_visibility='hidden')
    return input_text

# sidebar with options
with st.sidebar.expander("⚙️", expanded=True):
    text2text_model = st.selectbox(label='Text2Text Model', options=TEXT2TEXT_MODEL_LIST)
    embeddings_model = st.selectbox(label='Embeddings Model', options=EMBEDDINGS_MODEL_LIST)
    mode = st.selectbox(label='Mode', options=MODE_VALUES)

# streamlit app layout: sidebar + main panel
# the main panel has a title, a subheader, a user input textbox
# and a text area for the response and history
st.title("👩‍💻 Virtual Assistant")
st.subheader(f"Powered by :blue[{TEXT2TEXT_MODEL_LIST[0]}] for text generation and :blue[{EMBEDDINGS_MODEL_LIST[0]}] for embeddings")

# get user input
user_input: str = get_user_input()

# based on the selected mode, call the appropriate API endpoint
if user_input:
    # headers for request and response encoding, same for both endpoints
    headers: Dict = {"accept": "application/json", "Content-Type": "application/json"}
    output: Optional[str] = None
    if mode == MODE_TEXT2TEXT:
        # data = {"q": user_input}
        # resp = req.post(api_text2text_ep, headers=headers, json=data)
        # if resp.status_code != HTTP_OK:
        #     output = resp.text
        # else:
        #     output = resp.json()['answer'][0]
        output = ask(user_input)
    elif mode == MODE_RAG:
        # data = {"q": user_input, "verbose": True}
        # resp = req.post(api_rag_ep, headers=headers, json=data)
        # if resp.status_code != HTTP_OK:
        #     output = resp.text
        # else:
        #     resp = resp.json()
        #     sources = [d['metadata']['source'] for d in resp['docs']]
        #     output = f"{resp['answer']} \n \n Sources: {sources}"
        # placeholder response until the RAG endpoint is wired up
        output = "Response here 2"
    else:
        print("error")
        output = f"unhandled mode value={mode}"
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output)

# download the chat history
download_str: List = []
with st.expander("Conversation", expanded=True):
    for i in range(len(st.session_state['generated'])-1, -1, -1):
        st.info(st.session_state["past"][i], icon="❓")
        st.success(st.session_state["generated"][i], icon="👩‍💻")
        download_str.append(st.session_state["past"][i])
        download_str.append(st.session_state["generated"][i])

    download_str = '\n'.join(download_str)
    if download_str:
        st.download_button('Download', download_str)
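
# a minimal sketch of how to launch this app locally, assuming streamlit is
# installed and a qna.py module providing ask() sits alongside this file:
#   streamlit run webapp.py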