From f8c90e64ae070cbcb4fee81080f31a00a758284f Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Wed, 25 Sep 2024 17:18:07 +0000 Subject: [PATCH] update --- inference/python/streamlit/app.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/inference/python/streamlit/app.py b/inference/python/streamlit/app.py index 5ed56148c9..4d8633e167 100644 --- a/inference/python/streamlit/app.py +++ b/inference/python/streamlit/app.py @@ -103,15 +103,17 @@ def generate_llama2_response(prompt_input): if 'hf_token' in st.session_state.keys(): st.success('HF token already provided!', icon='✅') hf_token = st.session_state.hf_token - print(hf_token) else: hf_token = st.text_input('Enter your Hugging Face token:', type='password') if not (hf_token.startswith('hf_') and len(hf_token)==37): - st.warning('Please enter valid credentials!', icon='⚠️') + st.warning('Please enter a valid token', icon='⚠️') else: st.success('Proceed to finetuning your model!', icon='👉') st.session_state.hf_token = hf_token + # PEFT model name + peft_model_name = st.text_input("Enter the PEFT model name:", help="The name of the PEFT model should start with the username associated with the provided HF token, followed by '/'. E.g. 
'username/peft-base-uncased'") + # Dataset selection dataset_option = st.radio("Choose dataset source:", ["Upload JSON", "Hugging Face Dataset"]) @@ -123,6 +125,18 @@ def generate_llama2_response(prompt_input): else: dataset_name = st.text_input("Enter Hugging Face dataset name:") + # Finetuning parameters + st.subheader("Finetuning parameters") + lora_rank = st.number_input("LoRA rank", min_value=2, max_value=64, value=16, step=2) + lora_alpha = st.number_input("LoRA alpha", min_value=2, max_value=64, value=16, step=2) + target_modules = st.multiselect("Target modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"], default=["down_proj"]) + learning_rate = st.number_input("Learning rate", min_value=1e-6, max_value=1e-3, value=1e-5, step=1e-6) + optimizer_type = st.selectbox("Optimizer type", ["SGD", "Adam", "AdamW", "Adagrad", "Adadelta", "Adamax", "RMSprop"]) + momentum = st.number_input("Momentum", min_value=0.0, max_value=1.0, value=0.0, step=0.01) + weight_decay = st.number_input("Weight decay", min_value=0.0, max_value=1.0, value=0.0, step=0.01) + nesterov = st.checkbox("Nesterov") + max_steps = st.number_input("Max steps", min_value=1000, max_value=100000, value=10000, step=1000) + # Start finetuning button if st.button("Start Finetuning"): if not hf_token: