This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Adds ChatML prompt template as a string to maintain configurability #124

Merged · 4 commits · Oct 26, 2023
2 changes: 1 addition & 1 deletion cht-llama-cpp/build-aarch64-apple-darwin.sh
@@ -1,7 +1,7 @@
#!/bin/bash
set -e

-export VERSION=1.1.1
+export VERSION=1.1.2

test -f venv/bin/activate || python -m venv venv
source venv/bin/activate
2 changes: 1 addition & 1 deletion cht-llama-cpp/build.sh
@@ -1,6 +1,6 @@
#!/bin/bash
set -e
-export VERSION=1.1.0
+export VERSION=1.1.2
source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"

build_cpu ghcr.io/premai-io/chat-mistral-7b-instruct-q5 mistral-7b-instruct-v0.1.Q5_0 --build-arg="MODEL_ID=mistral-7b-instruct-v0.1.Q5_0" --build-arg="MODEL_DOWNLOAD_URL=https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q5_0.gguf" ${@:1}
6 changes: 5 additions & 1 deletion cht-llama-cpp/main.py
@@ -11,6 +11,10 @@
load_dotenv()

MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'mistral-7b-instruct-v0.1.Q5_0')}.gguf"
# Mistral gguf follows ChatML syntax
# https://github.com/openai/openai-python/blob/main/chatml.md
PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}' # noqa

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH)
@@ -29,7 +33,7 @@ def create_start_app_handler(app: FastAPI):
    def start_app() -> None:
        from models import LLaMACPPBasedModel

-        LLaMACPPBasedModel.get_model(MODEL_PATH)
+        LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING)

    return start_app

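Since the template now travels as a plain JSON string, it can be parsed and rendered without loading a model. A minimal sketch, assuming main.py is importable the same way the tests import it; the message text is made up:

import json
from main import PROMPT_TEMPLATE_STRING

template = json.loads(PROMPT_TEMPLATE_STRING)
# Render a single user turn with the ChatML user template.
user_turn = template["user_prompt_template"].format("Why ChatML?")
# user_turn == "<|im_start|>user\nWhy ChatML?\n<|im_end|>\n"
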
58 changes: 54 additions & 4 deletions cht-llama-cpp/models.py
@@ -1,12 +1,26 @@
import json
import multiprocessing
from typing import Any, Dict, List

-from llama_cpp import Llama
+from llama_cpp import Llama, llama_chat_format, llama_types

DEFAULT_N_THREADS = max(multiprocessing.cpu_count() // 2, 1)


@llama_chat_format.register_chat_format("chatml")
def initiate_chatml_prompt_template(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> llama_chat_format.ChatFormatterResponse:
    # TODO: drop when https://github.com/abetlen/llama-cpp-python/issues/717 supports ChatML

    _prompt = LLaMACPPBasedModel.stitch_prompt(messages, LLaMACPPBasedModel.PROMPT_TEMPLATE)
    return llama_chat_format.ChatFormatterResponse(prompt=_prompt)


class LLaMACPPBasedModel(object):
    model = None
    PROMPT_TEMPLATE = {}

    @classmethod
    def tokenize(cls, prompt):
@@ -43,22 +57,58 @@ def generate(
        stop = []
        messages = cls.reduce_number_of_messages(messages[::-1], max_tokens)[::-1]
        cls.model.n_threads = n_threads
-        return cls.model.create_chat_completion(
+        cht_resp = cls.model.create_chat_completion(
            messages,
            temperature=temperature,
            top_p=top_p,
            stream=stream,
            stop=stop,
            max_tokens=max_tokens,
        )
        if not stream and cls.PROMPT_TEMPLATE.get("template_format") == "chatml":
            cht_resp["choices"][0]["message"]["content"] = (
                cht_resp["choices"][0]["message"]["content"].split("\n<|im_end|>")[0].strip()
            )

        # TODO: handle postprocessing for streaming responses

        return cht_resp

    @classmethod
-    def get_model(cls, model_path):
+    def get_model(cls, model_path, prompt_template_jsonstr):
        chat_format = "llama-2"
        if "mistral" in model_path:
            cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr)
            chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml")
        if cls.model is None:
-            cls.model = Llama(model_path)
+            cls.model = Llama(model_path, chat_format=chat_format)

        return cls.model

    @classmethod
    def embeddings(cls, text):
        return cls.model.create_embedding(text)

    @staticmethod
    def stitch_prompt(messages: list, prompt_template: Dict[str, str]) -> str:
        system_prompt_template = prompt_template["system_prompt_template"]
        default_system_text = prompt_template["default_system_text"]
        user_prompt_template = prompt_template["user_prompt_template"]
        assistant_prompt_template = prompt_template["assistant_prompt_template"]
        request_assistant_response_token = prompt_template.get("request_assistant_response_token", "")

        system_prompt, chat_prompt = "", ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            if role == "system":
                system_prompt = system_prompt_template.format(content)
            elif role == "user":
                chat_prompt += user_prompt_template.format(content)
            elif role == "assistant":
                chat_prompt += assistant_prompt_template.format(content)

        if not system_prompt:
            system_prompt = system_prompt_template.format(default_system_text)

        return system_prompt + chat_prompt + request_assistant_response_token
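
The non-streaming branch in generate() above trims everything from the first "\n<|im_end|>" marker onward, so the caller only sees the assistant's own text. A self-contained sketch of that trimming, using a made-up completion string:

# Hypothetical raw completion; real output comes from create_chat_completion().
raw = "Sure, here is an answer.\n<|im_end|>\n<|im_start|>assistant\n"
clean = raw.split("\n<|im_end|>")[0].strip()
# clean == "Sure, here is an answer."
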
34 changes: 33 additions & 1 deletion cht-llama-cpp/tests/test_views.py
@@ -1,5 +1,8 @@
import json

from fastapi.testclient import TestClient
-from main import get_application
+from main import PROMPT_TEMPLATE_STRING, get_application
from models import LLaMACPPBasedModel


def test_chat_llama_cpp() -> None:
@@ -24,3 +27,32 @@ def test_chat_llama_cpp() -> None:
        },
    )
    assert response.status_code == 200


def test_chatml_stitch_prompt():
    messages = [
        {"role": "user", "content": "Why should we run ML models on premise?"},
        {
            "role": "assistant",
            "content": "There are several reasons why an organization might choose to run machine learning (ML) models on-premise:\n\n1. Security and privacy concerns: Running ML models on-premise allows organizations to",  # noqa
        },
    ]
    prompt_template = json.loads(PROMPT_TEMPLATE_STRING)
    assert prompt_template["template_format"] == "chatml"
    result = LLaMACPPBasedModel.stitch_prompt(messages, prompt_template=prompt_template)
    assert (
        result
        == """<|im_start|>system
You are an helpful AI assistant.
<|im_end|>
<|im_start|>user
Why should we run ML models on premise?
<|im_end|>
<|im_start|>assistant
There are several reasons why an organization might choose to run machine learning (ML) models on-premise:

1. Security and privacy concerns: Running ML models on-premise allows organizations to
<|im_end|>
<|im_start|>assistant
"""
    )
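
Assuming the repository's usual pytest setup, the new assertion can be exercised on its own with something like:

pytest cht-llama-cpp/tests/test_views.py::test_chatml_stitch_prompt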