Skip to content

Commit

Permalink
Update file
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-doerr committed Oct 16, 2022
1 parent 709b42d commit 65c3cd4
Showing 1 changed file with 75 additions and 25 deletions.
100 changes: 75 additions & 25 deletions streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import time
import math

PROMPT_PREFIX = '''
Info: "With over 100 active members, TUM.ai is Germany's leading AI student initiative, located in Munich. 🎓
Expand Down Expand Up @@ -112,6 +113,12 @@

PROMPTS_LOG_CSV = 'propmts.csv'
RESPONSES_LOG_CSV = 'responses.csv'
CHARACTERS_TOKEN_RATIO_ESTIMATE = 4.66
PROMPT_FILE = 'prompt.txt'
with open(PROMPT_FILE, 'r') as f:
PROMPT_PREFIX = f.read()



if (('log' in st.experimental_get_query_params()) and st.experimental_get_query_params()['log'][0] == 'true'):
st.title('Showing log')
Expand All @@ -127,6 +134,32 @@
st.write('No log found')


def get_roughly_n_tokens_section(text, n_tokens):
text_split_empty_lines = text.split('\n\n')
num_target_chars = n_tokens * CHARACTERS_TOKEN_RATIO_ESTIMATE
line_num_last_added = 0
section_texts = []
while True:
out_text = ''
# for i, line in enumerate(text_split_empty_lines[i:]):
for i in range(line_num_last_added + 1, len(text_split_empty_lines)):
print("i:", i)
line = text_split_empty_lines[i]
if len(out_text) + len(line) > num_target_chars:
break

if i >= len(text_split_empty_lines) - 1:
break

out_text += line + '\n\n'
line_num_last_added = i

if i >= len(text_split_empty_lines) - 1:
break

section_texts.append(out_text)

return section_texts


def write_page_load_stats():
Expand Down Expand Up @@ -167,12 +200,12 @@ def get_num_prompts_last_x_min(mins):


MINUTES_TO_CONSIDER = 60
MAX_REQUESTS_PER_MINUTE = 120
MAX_REQUESTS_PER_MINUTE = 12

num_prompts_last_x_min = get_num_prompts_last_x_min(MINUTES_TO_CONSIDER)

print("num_prompts_last_x_min:", num_prompts_last_x_min)
if num_prompts_last_x_min >= MAX_REQUESTS_PER_MINUTE:
if num_prompts_last_x_min >= MAX_REQUESTS_PER_MINUTE * MINUTES_TO_CONSIDER:
st.info('Hit the rate limit, please try again in a few minutes.')
st.stop()

Expand All @@ -197,8 +230,11 @@ def initialize_openai_api():
st.stop()

log_prompt(user_input)
print('creating sections ...')
sections = get_roughly_n_tokens_section(PROMPT_PREFIX, 1000)
print('sections created')
print("sections:", sections)

prompt = PROMPT_PREFIX + user_input + PROMPT_POSTFIX
print("user_input:", user_input)

SUFFIX = '''"
Expand All @@ -212,27 +248,41 @@ def initialize_openai_api():
else:
suffix = None

response = openai.Completion.create(engine=MODEL, prompt=prompt, suffix=suffix, temperature=0.5, stream=True, stop='User: "', max_tokens=250)

completion_all = ''
response_text_field = st.empty()

while True:
next_response = next(response)
completion = next_response['choices'][0]['text']
if next_response['choices'][0]['finish_reason']:
if completion_all[-1] == '"':
completion_all = completion_all.strip()[:-1]

completion_all += completion
# print("completion_all:", completion_all)
# response_text_field.text(completion_all)
response_text_field.markdown(completion_all)
if next_response['choices'][0]['finish_reason']:
break

print("completion_all:", completion_all)

log_response(completion_all)
with st.spinner('Thinking about possible answers...'):
columns = st.columns(len(sections))
for section_num, section in enumerate(sections):
prompt_prefix = section
prompt = prompt_prefix + user_input + PROMPT_POSTFIX
response = openai.Completion.create(engine=MODEL, prompt=prompt, suffix=suffix, temperature=0.5, stream=True, stop='User: "', max_tokens=250, logprobs=1)

completion_all = ''
logprob_values = []


with columns[section_num]:
response_text_field = st.empty()
while True:
next_response = next(response)
print("next_response:", next_response)
completion = next_response['choices'][0]['text']
logprob_values.append(next_response['choices'][0]['logprobs']['token_logprobs'][0])
if next_response['choices'][0]['finish_reason']:
if completion_all[-1] == '"':
completion_all = completion_all.strip()[:-1]

completion_all += completion
# print("completion_all:", completion_all)
# response_text_field.text(completion_all)
response_text_field.markdown(completion_all)
if next_response['choices'][0]['finish_reason']:
break

print("completion_all:", completion_all)
logprob_avg = sum(logprob_values) / len(logprob_values)
# st.write(f'Average logprob: {logprob_avg}')
# st.write(f'Certainty: {math.exp(logprob_avg)}')
response_text_field.markdown(completion_all + f'\n\nCertainty: {math.exp(logprob_avg)}')

log_response(completion_all)


0 comments on commit 65c3cd4

Please sign in to comment.