Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Sep 25, 2024
1 parent a2d2ac0 commit f8c90e6
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions inference/python/streamlit/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,17 @@ def generate_llama2_response(prompt_input):
if 'hf_token' in st.session_state.keys():
st.success('HF token already provided!', icon='✅')
hf_token = st.session_state.hf_token
print(hf_token)
else:
hf_token = st.text_input('Enter your Hugging Face token:', type='password')
if not (hf_token.startswith('hf_') and len(hf_token)==37):
st.warning('Please enter valid credentials!', icon='⚠️')
st.warning('please enter a valid token', icon='⚠️')
else:
st.success('Proceed to finetuning your model!', icon='👉')
st.session_state.hf_token = hf_token

# PEFT model name
peft_model_name = st.text_input("Enter the PEFT model name:", help="The name of the PEFT model should start with the username associated with the provided HF token, followed by '/'. E.g. 'username/peft-base-uncased'")

# Dataset selection
dataset_option = st.radio("Choose dataset source:", ["Upload JSON", "Hugging Face Dataset"])

Expand All @@ -123,6 +125,18 @@ def generate_llama2_response(prompt_input):
else:
dataset_name = st.text_input("Enter Hugging Face dataset name:")

# Finetuning parameters
st.subheader("Finetuning parameters")
lora_rank = st.number_input("LoRA rank", min_value=2, max_value=64, value=16, step=2)
lora_alpha = st.number_input("LoRA alpha", min_value=2, max_value=64, value=16, step=2)
target_modules = st.multiselect("Target modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"], default=["down_proj"])
learning_rate = st.number_input("Learning rate", min_value=1e-6, max_value=1e-3, value=1e-5, step=1e-6)
optimizer_type = st.selectbox("Optimizer type", ["SGD", "Adam", "AdamW", "Adagrad", "Adadelta", "Adamax", "RMSprop"])
momentum = st.number_input("Momentum", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
weight_decay = st.number_input("Weight decay", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
nesterov = st.checkbox("Nesterov")
max_steps = st.number_input("Max steps", min_value=1000, max_value=100000, value=10000, step=1000)

# Start finetuning button
if st.button("Start Finetuning"):
if not hf_token:
Expand Down

0 comments on commit f8c90e6

Please sign in to comment.