Skip to content

Commit

Permalink
fix: bs>1 bug for peft models (quic#164)
Browse files Browse the repository at this point in the history
Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>
  • Loading branch information
irajagop authored Oct 30, 2024
1 parent 205f1d7 commit b33197f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion QEfficient/peft/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def generate(

# Decode loop
for num_token in range(1, generation_config.max_new_tokens):
if stopping_criteria(torch.from_numpy(inputs["input_ids"]), torch.from_numpy(outputs["logits"])):
if stopping_criteria(torch.from_numpy(inputs["input_ids"]), torch.from_numpy(outputs["logits"])).all():
break

outputs = self.qpc_session.run(inputs)
Expand Down
18 changes: 10 additions & 8 deletions tests/peft/test_peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,29 +158,31 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con
qeff_model.set_adapter("invalid")


@pytest.mark.parametrize("batch_size", [1, 4], ids=["bs1", "bs4"])
@pytest.mark.parametrize("base_config,adapter_config", configs)
def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, tmp_path):
def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path):
_, lora_model = create_peft_model(base_config, adapter_config)
qeff_model = QEffAutoPeftModelForCausalLM(lora_model)
qeff_model.export(tmp_path)
start = perf_counter()
qeff_model.compile(prefill_seq_len=32, ctx_len=128)
qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
end = perf_counter()
compile_time_0 = end - start

qeff_model.generate(
input_ids=np.zeros((1, 32), dtype="int64"),
input_ids=np.zeros((batch_size, 32), dtype="int64"),
attention_mask=np.concatenate(
[
np.ones(10, dtype="int64"),
np.zeros(22, dtype="int64"),
]
).reshape(1, 32),
np.ones((batch_size, 10), dtype="int64"),
np.zeros((batch_size, 22), dtype="int64"),
],
axis=1,
),
max_new_tokens=10,
)

start = perf_counter()
qeff_model.compile(prefill_seq_len=32, ctx_len=128)
qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
end = perf_counter()
compile_time_1 = end - start
assert compile_time_1 < 0.01 * compile_time_0

0 comments on commit b33197f

Please sign in to comment.