-
Notifications
You must be signed in to change notification settings - Fork 290
/
main.py
167 lines (129 loc) · 5.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# 1: Import libraries
import streamlit as st
from groq import Groq
import json
from infinite_bookshelf.agents import (
generate_section,
generate_book_structure,
generate_book_title,
)
from infinite_bookshelf.inference import GenerationStatistics
from infinite_bookshelf.tools import create_markdown_file, create_pdf_file
from infinite_bookshelf.ui.components import (
render_groq_form,
display_statistics,
render_download_buttons,
)
from infinite_bookshelf.ui import Book, load_return_env, ensure_states
# 2: Initialize env variables and session states
GROQ_API_KEY = load_return_env(["GROQ_API_KEY"])["GROQ_API_KEY"]

# Default values for every session-state key this page reads.
states = {
    "api_key": GROQ_API_KEY,
    "button_disabled": False,
    "button_text": "Generate",
    "statistics_text": "",
    "book_title": "",
}

if GROQ_API_KEY:
    # Pass the key explicitly: a bare Groq() only consults the GROQ_API_KEY
    # process environment variable, which is not necessarily where
    # load_return_env found it (NOTE(review): confirm load_return_env's source).
    # Without a configured key, the client is created later from the key the
    # user types into the form.
    states["groq"] = Groq(api_key=GROQ_API_KEY)

ensure_states(states)
# 3: Define Streamlit page structure and functionality
_PAGE_HEADER = """
# Infinite Bookshelf: Write full books using llama3.3 70b on Groq
"""
st.write(_PAGE_HEADER)

# Banner row: an info notice in the wide left column, the Groq logo on the right.
col1, col2 = st.columns([0.7, 0.3])
col1.info(
    "You are using a streamlined version. Try the new [advanced version](/advanced) in beta."
)
col2.image("assets/logo/powered-by-groq.svg", width=150)
def _set_button_disabled(flag):
    """Record whether the Generate button should render disabled on the next run."""
    st.session_state["button_disabled"] = flag


def disable():
    """Disable the Generate button; passed to the form as its on_submit callback."""
    _set_button_disabled(True)


def enable():
    """Re-enable the Generate button."""
    _set_button_disabled(False)


def empty_st():
    """Create an empty Streamlit placeholder element and discard it."""
    st.empty()
# Single model used by the structure, title and section writers and by the
# statistics label, kept in one place so they cannot drift apart.
MODEL_NAME = "llama-3.3-70b-specdec"

# Main page flow: render the download and generation controls; on submit,
# generate a book structure, a title, then stream every section's content.
# Any unexpected error re-enables the form and is shown to the user.
try:
    if st.button("End Generation and Download Book"):
        if "book" in st.session_state:
            render_download_buttons(st.session_state.get("book"))

    submitted, groq_input_key, topic_text, additional_instructions = render_groq_form(
        on_submit=disable,
        button_disabled=st.session_state.button_disabled,
        button_text=st.session_state.button_text,
    )

    if submitted:
        if len(topic_text) < 10:
            raise ValueError("Book topic must be at least 10 characters long")

        st.session_state.button_disabled = True
        st.session_state.statistics_text = (
            "Generating book title and structure in background...."
        )
        placeholder = st.empty()
        display_statistics(
            placeholder=placeholder, statistics_text=st.session_state.statistics_text
        )

        # No key from the environment: build the client from the form input.
        if not GROQ_API_KEY:
            st.session_state.groq = Groq(api_key=groq_input_key)

        # Step 1: Generate book structure using structure_writer agent
        # (the returned statistics are currently not used).
        large_model_generation_statistics, book_structure = generate_book_structure(
            prompt=topic_text,
            additional_instructions=additional_instructions,
            model=MODEL_NAME,
            groq_provider=st.session_state.groq,
        )

        # Step 2: Generate book title using title_writer agent
        st.session_state.book_title = generate_book_title(
            prompt=topic_text,
            model=MODEL_NAME,
            groq_provider=st.session_state.groq,
        )
        st.write(f"## {st.session_state.book_title}")

        total_generation_statistics = GenerationStatistics(model_name=MODEL_NAME)

        # Step 3: Generate book section content using section_writer agent
        try:
            book_structure_json = json.loads(book_structure)
        except json.JSONDecodeError:
            # Re-enable the form so the "try again" advice is actually possible.
            st.session_state.button_disabled = False
            st.error("Failed to decode the book structure. Please try again.")
        else:
            # Always replace the stored book: keeping a previous run's Book
            # would display and fill the stale structure instead of this one.
            st.session_state.book = Book(
                st.session_state.book_title, book_structure_json
            )

            # Print the book structure to the terminal to show structure
            print(json.dumps(book_structure_json, indent=2))
            st.session_state.book.display_structure()

            def stream_section_content(sections):
                """Stream generated text into the book for every leaf section.

                String values are section descriptions to generate from; dict
                values are nested sub-sections and are walked recursively.
                """
                for title, content in sections.items():
                    if isinstance(content, str):
                        content_stream = generate_section(
                            prompt=(title + ": " + content),
                            additional_instructions=additional_instructions,
                            model=MODEL_NAME,
                            groq_provider=st.session_state.groq,
                        )
                        for chunk in content_stream:
                            # The stream yields text tokens, plus GenerationStatistics
                            # objects that update the running totals.
                            if isinstance(chunk, GenerationStatistics):
                                total_generation_statistics.add(chunk)
                                st.session_state.statistics_text = str(
                                    total_generation_statistics
                                )
                                display_statistics(
                                    placeholder=placeholder,
                                    statistics_text=st.session_state.statistics_text,
                                )
                            elif chunk is not None:
                                st.session_state.book.update_content(title, chunk)
                    elif isinstance(content, dict):
                        stream_section_content(content)

            stream_section_content(book_structure_json)

except Exception as e:
    st.session_state.button_disabled = False
    st.error(e)

    if st.button("Clear"):
        st.rerun()