Skip to content

Commit

Permalink
Fixed temporary issues in prompt manager
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravarthik27 committed Oct 17, 2024
1 parent e13b779 commit 9b6862c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion langtest/modelhandler/llm_modelhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def load_model(cls, hub: str, path: str, *args, **kwargs) -> "PretrainedModelFor
ValueError: If the model is not found online or locally.
ConfigError: If there is an error in the model configuration.
"""
exclude_args = ["task", "device", "stream", "model_type"]
exclude_args = ["task", "device", "stream", "model_type", "chat_template"]

model_type = kwargs.get("model_type", None)

Expand Down
4 changes: 2 additions & 2 deletions langtest/modelhandler/transformers_modelhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,15 +804,15 @@ def predict(self, text: Union[str, dict], prompt: dict, **kwargs) -> str:

if examples:
prompt["template"] = "".join(
f"{k.title()}: {{{k}}}" for k in text.keys()
f"{k.title()}:\n{{{k}}}\n" for k in text.keys()
)
prompt_template = SimplePromptTemplate(**prompt)
text = prompt_template.format(**text)
messages = [*examples, {"role": "user", "content": text}]
else:
messages = [{"role": "user", "content": text}]
output = self.model._generate([messages])
return output[0][0].get("content", "")
return output[0].strip()

else:
prompt_template = SimplePromptTemplate(**prompt)
Expand Down
15 changes: 13 additions & 2 deletions langtest/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,17 @@ def get_template(self):

temp = []
order_less = []
for field in self.__dict__:

sorted_fields = sorted(
self.__dict__.keys(), key=lambda x: self.__field_order.index(x.lower())
)

for field in sorted_fields:
if field in self.__field_order:
temp.append(f"{field.title()}: {{{field}}}")
else:
order_less.append(f"{field.title()}: {{{field}}}")

if order_less:
temp.extend(order_less)
return "\n" + "\n".join(temp)
Expand Down Expand Up @@ -194,7 +200,12 @@ def lm_studio_prompt(self):

# assistant role
temp_ai["role"] = "assistant"
temp_ai["content"] = example.ai.get_template.format(**example.ai.get_example)
temp_ai["content"] = (
example.ai.get_template.format(**example.ai.get_example)
.replace("Answer:", "")
.strip()
+ "\n\n"
)

messages.append(temp_user)
messages.append(temp_ai)
Expand Down

0 comments on commit 9b6862c

Please sign in to comment.