Some revisions and refactoring of the interface
lfoppiano committed Jun 24, 2024
1 parent 3b9ffd5 commit d428544
Showing 2 changed files with 64 additions and 115 deletions.
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@ grobid_tei_xml==0.1.3
 tqdm==4.66.2
 pyyaml==6.0.1
 pytest==8.1.1
-streamlit==1.33.0
+streamlit==1.36.0
 lxml
 Beautifulsoup4
 python-dotenv
streamlit_app.py (177 changes: 63 additions & 114 deletions)
@@ -42,8 +42,6 @@
     'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral'
 }
 
-DISABLE_MEMORY = ['zephyr-7b-beta']
-
 if 'rqa' not in st.session_state:
     st.session_state['rqa'] = {}
 
@@ -108,36 +106,6 @@
     }
 )
 
-css_modify_left_column = '''
-<style>
-[data-testid="stHorizontalBlock"] > div:nth-child(1) {
-    overflow: hidden;
-    background-color: red;
-    height: 70vh;
-}
-</style>
-'''
-css_modify_right_column = '''
-<style>
-[data-testid="stHorizontalBlock"]> div:first-child {
-    background-color: red;
-    position: fixed
-    height: 70vh;
-}
-</style>
-'''
-css_disable_scrolling_container = '''
-<style>
-[data-testid="ScrollToBottomContainer"] {
-    overflow: hidden;
-}
-</style>
-'''
-
-
-# st.markdown(css_lock_column_fixed, unsafe_allow_html=True)
-# st.markdown(css2, unsafe_allow_html=True)
-
 
 def new_file():
     st.session_state['loaded_embeddings'] = None
@@ -188,7 +156,7 @@ def init_qa(model, embeddings_name=None, api_key=None):
         )
         embeddings = HuggingFaceEmbeddings(
             model_name=OPEN_EMBEDDINGS[embeddings_name])
-        st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
+        # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
@@ -233,23 +201,27 @@ def get_file_hash(fname):
     return hash_md5.hexdigest()
 
 
-def play_old_messages():
+def play_old_messages(container):
     if st.session_state['messages']:
         for message in st.session_state['messages']:
             if message['role'] == 'user':
-                with st.chat_message("user"):
-                    st.markdown(message['content'])
+                container.chat_message("user").markdown(message['content'])
            elif message['role'] == 'assistant':
-                with st.chat_message("assistant"):
-                    if mode == "LLM":
-                        st.markdown(message['content'], unsafe_allow_html=True)
-                    else:
-                        st.write(message['content'])
+                if mode == "LLM":
+                    container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True)
+                else:
+                    container.chat_message("assistant").write(message['content'])
 
 
 # is_api_key_provided = st.session_state['api_key']
 
 with st.sidebar:
+    st.title("📝 Scientific Document Insights Q/A")
+    st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
+    st.markdown(
+        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
+
+    st.divider()
     st.session_state['model'] = model = st.selectbox(
         "Model:",
         options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
@@ -305,22 +277,18 @@ def play_old_messages():
     # else:
     #     is_api_key_provided = st.session_state['api_key']
 
-    st.button(
-        'Reset chat memory.',
-        key="reset-memory-button",
-        on_click=clear_memory,
-        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
-        disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
+    # st.button(
+    #     'Reset chat memory.',
+    #     key="reset-memory-button",
+    #     on_click=clear_memory,
+    #     help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+    #     disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
 left_column, right_column = st.columns([1, 1])
+right_column = right_column.container(height=600, border=False)
+left_column = left_column.container(height=600, border=False)
 
 with right_column:
-    st.title("📝 Scientific Document Insights Q/A")
-    st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
-
-    st.markdown(
-        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
-
     uploaded_file = st.file_uploader(
         "Upload an article",
         type=("pdf", "txt"),
@@ -330,11 +298,14 @@ def play_old_messages():
         help="The full-text is extracted using Grobid."
     )
 
-    question = st.chat_input(
-        "Ask something about the article",
-        # placeholder="Can you give me a short summary?",
-        disabled=not uploaded_file
-    )
+    placeholder = st.empty()
+    messages = st.container(height=300, border=False)
+
+    question = st.chat_input(
+        "Ask something about the article",
+        # placeholder="Can you give me a short summary?",
+        disabled=not uploaded_file
+    )
 
 query_modes = {
     "llm": "LLM Q/A",
@@ -355,6 +326,10 @@ def play_old_messages():
         "relevant paragraphs to the question in the paper. "
         "Question coefficient attempt to estimate how effective the question will be answered."
     )
+    st.session_state['ner_processing'] = st.checkbox(
+        "Identify materials and properties.",
+        help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.'
+    )
 
     # Add a checkbox for showing annotations
     # st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True)
@@ -372,11 +347,6 @@ def play_old_messages():
         help="Number of chunks to consider when answering a question",
         disabled=not uploaded_file)
 
-    st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
-    st.markdown(
-        'The LLM responses undergo post-processing to extract <span style="color:orange">physical quantities, measurements</span>, and <span style="color:green">materials</span> mentions.',
-        unsafe_allow_html=True)
-
     st.divider()
 
     st.header("Documentation")
@@ -403,7 +373,7 @@ def play_old_messages():
         st.error("Before uploading a document, you must enter the API key. ")
         st.stop()
 
-with right_column:
+with left_column:
     with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
         binary = uploaded_file.getvalue()
         tmp_file = NamedTemporaryFile()
@@ -416,8 +386,6 @@ def play_old_messages():
         st.session_state['loaded_embeddings'] = True
         st.session_state.messages = []
 
-# timestamp = datetime.utcnow()
-
 
 def rgb_to_hex(rgb):
     return "#{:02x}{:02x}{:02x}".format(*rgb)
@@ -439,41 +407,21 @@ def generate_color_gradient(num_elements):
 
 
 with right_column:
-    # css = '''
-    # <style>
-    # [data-testid="column"] {
-    #     overflow: auto;
-    #     height: 70vh;
-    # }
-    # </style>
-    # '''
-    # st.markdown(css, unsafe_allow_html=True)
-
-    # st.markdown(
-    #     """
-    #     <script>
-    #     document.querySelectorAll('[data-testid="column"]').scrollIntoView({behavior: "smooth"});
-    #     </script>
-    #     """,
-    #     unsafe_allow_html=True,
-    # )
-
     if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
         for message in st.session_state.messages:
-            with st.chat_message(message["role"]):
+            with messages.chat_message(message["role"]):
                 if message['mode'] == "llm":
-                    st.markdown(message["content"], unsafe_allow_html=True)
+                    messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
                 elif message['mode'] == "embeddings":
-                    st.write(message["content"])
+                    messages.chat_message(message["role"]).write(message["content"])
                 if message['mode'] == "question_coefficient":
-                    st.markdown(message["content"], unsafe_allow_html=True)
+                    messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
         if model not in st.session_state['rqa']:
             st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
             st.stop()
 
-        with st.chat_message("user"):
-            st.markdown(question)
-            st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
+        messages.chat_message("user").markdown(question)
+        st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
 
         text_response = None
         if mode == "embeddings":
@@ -484,12 +432,13 @@ def generate_color_gradient(num_elements):
                     context_size=context_size
                 )
         elif mode == "llm":
-            with st.spinner("Generating LLM response..."):
-                _, text_response, coordinates = st.session_state['rqa'][model].query_document(
-                    question,
-                    st.session_state.doc_id,
-                    context_size=context_size
-                )
+            with placeholder:
+                with st.spinner("Generating LLM response..."):
+                    _, text_response, coordinates = st.session_state['rqa'][model].query_document(
+                        question,
+                        st.session_state.doc_id,
+                        context_size=context_size
+                    )
 
         elif mode == "question_coefficient":
             with st.spinner("Estimate question/context relevancy..."):
@@ -511,28 +460,28 @@ def generate_color_gradient(num_elements):
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
 
-        with st.chat_message("assistant"):
-            if mode == "llm":
-                if st.session_state['ner_processing']:
-                    with st.spinner("Processing NER on LLM response..."):
-                        entities = gqa.process_single_text(text_response)
-                        decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
-                        decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
-                        decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
-                        text_response = decorated_text
-                st.markdown(text_response, unsafe_allow_html=True)
-            else:
-                st.write(text_response)
-            st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
+        if mode == "llm":
+            if st.session_state['ner_processing']:
+                with st.spinner("Processing NER on LLM response..."):
+                    entities = gqa.process_single_text(text_response)
+                    decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
+                    decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+                    decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+                    text_response = decorated_text
+            messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True)
+        else:
+            messages.chat_message("assistant").write(text_response)
+        st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
     elif st.session_state.loaded_embeddings and st.session_state.doc_id:
-        play_old_messages()
+        play_old_messages(messages)
 
 with left_column:
     if st.session_state['binary']:
         pdf_viewer(
             input=st.session_state['binary'],
-            annotation_outline_size=1,
+            annotation_outline_size=2,
             annotations=st.session_state['annotations'],
-            render_text=True
+            render_text=True,
+            height=700
         )
