Skip to content

Commit

Permalink
revert default model, use additional environ variable to default it
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Jan 15, 2024
1 parent 70e7048 commit aeb450e
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def play_old_messages():
st.session_state['model'] = model = st.selectbox(
"Model:",
options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
index=OPENAI_MODELS.index('gpt-3.5-turbo'),
index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
"zephyr-7b-beta") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]),
placeholder="Select model",
help="Select the LLM model:",
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
Expand Down Expand Up @@ -313,8 +315,8 @@ def play_old_messages():
disabled=uploaded_file is not None)
if chunk_size == -1:
context_size = st.slider("Context size", 3, 20, value=10,
help="Number of paragraphs to consider when answering a question",
disabled=not uploaded_file)
help="Number of paragraphs to consider when answering a question",
disabled=not uploaded_file)
else:
context_size = st.slider("Context size", 3, 10, value=4,
help="Number of chunks to consider when answering a question",
Expand Down Expand Up @@ -363,17 +365,20 @@ def play_old_messages():

# timestamp = datetime.utcnow()


def rgb_to_hex(rgb):
return "#{:02x}{:02x}{:02x}".format(*rgb)


def generate_color_gradient(num_elements):
# Define warm and cold colors in RGB format
warm_color = (255, 165, 0) # Orange
cold_color = (0, 0, 255) # Blue
cold_color = (0, 0, 255) # Blue

# Generate a linear gradient of colors
color_gradient = [
rgb_to_hex(tuple(int(warm * (1 - i/num_elements) + cold * (i/num_elements)) for warm, cold in zip(warm_color, cold_color)))
rgb_to_hex(tuple(int(warm * (1 - i / num_elements) + cold * (i / num_elements)) for warm, cold in
zip(warm_color, cold_color)))
for i in range(num_elements)
]

Expand Down Expand Up @@ -427,7 +432,7 @@ def generate_color_gradient(num_elements):
context_size=context_size)
annotations = [
GrobidAggregationProcessor.box_to_dict(coo) for coo in [c.split(",") for coord in
coordinates for c in coord]
coordinates for c in coord]
]
gradients = generate_color_gradient(len(annotations))
for i, color in enumerate(gradients):
Expand Down Expand Up @@ -465,6 +470,7 @@ def generate_color_gradient(num_elements):
with left_column:
if st.session_state['binary']:
if st.session_state['should_show_annotations']:
pdf_viewer(input=st.session_state['binary'], width=600, height=800, annotations=st.session_state['annotations'])
pdf_viewer(input=st.session_state['binary'], width=600, height=800,
annotations=st.session_state['annotations'])
else:
pdf_viewer(input=st.session_state['binary'], width=600, height=800)

0 comments on commit aeb450e

Please sign in to comment.