-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
89 lines (77 loc) · 3.37 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import os
from services import bedrock_agent_runtime
import streamlit as st
import uuid
# Runtime configuration, read from environment variables.
# ID of the Bedrock agent to invoke (no default: None when unset).
agent_id = os.getenv("BEDROCK_AGENT_ID")
# Agent alias to invoke; TSTALIASID is the default test alias ID.
agent_alias_id = os.getenv("BEDROCK_AGENT_ALIAS_ID", "TSTALIASID")
# Page title (also used as the browser tab title) and optional page icon.
ui_title = os.getenv("BEDROCK_AGENT_TEST_UI_TITLE", "Agents for Amazon Bedrock Test UI")
ui_icon = os.getenv("BEDROCK_AGENT_TEST_UI_ICON")
def init_state():
    """Start a fresh conversation in Streamlit's session state.

    Assigns a new random session ID (passed to the agent on every call),
    clears the chat history, and clears the last agent trace.
    """
    state = st.session_state
    state.session_id = str(uuid.uuid4())
    state.messages = []
    state.trace = {}
# General page configuration and initialization
st.set_page_config(page_title=ui_title, page_icon=ui_icon, layout="wide")
st.title(ui_title)

# Initialize on the first run of a browser session. Check for the specific
# keys init_state() creates rather than for a completely empty session state:
# if any unrelated key were set first, an emptiness check would skip
# initialization and later reads of st.session_state.session_id / .messages /
# .trace would fail.
if any(key not in st.session_state for key in ("session_id", "messages", "trace")):
    init_state()
# Sidebar button that discards the current conversation and starts a new
# session (new session ID, empty history, empty trace).
if st.sidebar.button("Reset Session"):
    init_state()
# Replay the stored conversation; Streamlit reruns the script top to bottom on
# every interaction, so history must be re-rendered from session state.
for entry in st.session_state.messages:
    st.chat_message(entry["role"]).write(entry["content"])
# Chat input: echo the user's prompt, invoke the agent, then render and store
# the reply along with its trace (shown in the sidebar below).
prompt = st.chat_input()
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        # Temporary "..." indicator, replaced once the agent call returns.
        placeholder = st.empty()
        placeholder.markdown("...")
        response = bedrock_agent_runtime.invoke_agent(
            agent_id,
            agent_alias_id,
            st.session_state.session_id,
            prompt,
        )
        answer = response["output_text"]
        placeholder.markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})
        st.session_state.trace = response["trace"]
# Trace categories shown in the sidebar, in display order, mapped to their
# section headings.
trace_type_headers = dict(
    preProcessingTrace="Pre-Processing",
    orchestrationTrace="Orchestration",
    postProcessingTrace="Post-Processing",
)

# Keys checked (in this order) on each trace event to find its "traceId".
trace_info_types = [
    "invocationInput",
    "modelInvocationInput",
    "modelInvocationOutput",
    "observation",
    "rationale",
]
def _group_trace_steps(traces):
    """Group trace events into steps keyed by ``traceId``.

    For each event, the keys in ``trace_info_types`` are checked in order and
    the first one present supplies the ``traceId`` used as the grouping key.
    Events containing none of those keys are skipped. The returned dict keeps
    steps in the order their ``traceId`` was first seen, which defines the
    displayed step numbering.
    """
    steps = {}
    for event in traces:
        for info_type in trace_info_types:
            if info_type in event:
                steps.setdefault(event[info_type]["traceId"], []).append(event)
                break
    return steps


# Sidebar section for trace
with st.sidebar:
    st.title("Trace")

    # Show each trace type in a separate section
    for trace_type, header in trace_type_headers.items():
        st.subheader(header)

        # Organize traces by step similar to how it is shown in the Bedrock console
        all_traces = st.session_state.trace
        trace_steps = _group_trace_steps(all_traces[trace_type]) if trace_type in all_traces else {}

        # Also covers a trace type that is present but has no recognised events,
        # which would otherwise leave the section silently empty.
        if not trace_steps:
            st.text("None")
            continue

        # Show trace steps in JSON similar to the Bedrock console
        for step_num, step_traces in enumerate(trace_steps.values(), start=1):
            with st.expander(f"Trace Step {step_num}", expanded=False):
                for trace in step_traces:
                    # line_numbers is a boolean flag (a newline count was
                    # previously passed and only acted as a truthy value).
                    st.code(json.dumps(trace, indent=2), language="json", line_numbers=True)