From 5bc6781aff9532a362b147a68c2997baf28700e5 Mon Sep 17 00:00:00 2001 From: biswaroop1547 Date: Wed, 25 Oct 2023 23:51:03 +0530 Subject: [PATCH 1/4] update: chatml prompt template as string to maintain configurability + generic stitch_prompt function --- cht-llama-cpp/main.py | 6 ++++- cht-llama-cpp/models.py | 58 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/cht-llama-cpp/main.py b/cht-llama-cpp/main.py index ef00449..2059423 100644 --- a/cht-llama-cpp/main.py +++ b/cht-llama-cpp/main.py @@ -11,6 +11,10 @@ load_dotenv() MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'mistral-7b-instruct-v0.1.Q5_0')}.gguf" +# Mistral gguf follows ChatML syntax +# https://github.com/openai/openai-python/blob/main/chatml.md +PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}' # noqa + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH) @@ -29,7 +33,7 @@ def create_start_app_handler(app: FastAPI): def start_app() -> None: from models import LLaMACPPBasedModel - LLaMACPPBasedModel.get_model(MODEL_PATH) + LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING) return start_app diff --git a/cht-llama-cpp/models.py b/cht-llama-cpp/models.py index 64a62eb..756e0c2 100644 --- a/cht-llama-cpp/models.py +++ b/cht-llama-cpp/models.py @@ -1,12 +1,26 @@ +import json import multiprocessing +from typing import Any, Dict, List -from llama_cpp import Llama +from llama_cpp import Llama, llama_chat_format, llama_types DEFAULT_N_THREADS = max(multiprocessing.cpu_count() // 2, 1) +@llama_chat_format.register_chat_format("chatml") +def initiate_chatml_prompt_template( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> llama_chat_format.ChatFormatterResponse: + # Until https://github.com/abetlen/llama-cpp-python/issues/717 supports ChatML. 
+ + _prompt = LLaMACPPBasedModel.stitch_prompt(messages, LLaMACPPBasedModel.PROMPT_TEMPLATE) + return llama_chat_format.ChatFormatterResponse(prompt=_prompt) + + class LLaMACPPBasedModel(object): model = None + PROMPT_TEMPLATE = {} @classmethod def tokenize(cls, prompt): @@ -43,7 +57,7 @@ def generate( stop = [] messages = cls.reduce_number_of_messages(messages[::-1], max_tokens)[::-1] cls.model.n_threads = n_threads - return cls.model.create_chat_completion( + cht_resp = cls.model.create_chat_completion( messages, temperature=temperature, top_p=top_p, @@ -51,14 +65,50 @@ def generate( stop=stop, max_tokens=max_tokens, ) + if not stream and cls.PROMPT_TEMPLATE.get("template_format") == "chatml": + cht_resp["choices"][0]["message"]["content"] = ( + cht_resp["choices"][0]["message"]["content"].split("\n<|im_end|>")[0].strip() + ) + + # TODO: handle postprocessing for streaming responses + + return cht_resp @classmethod - def get_model(cls, model_path): + def get_model(cls, model_path, prompt_template_jsonstr): + chat_format = "llama-2" + if "mistral" in model_path: + cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr) + chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml") if cls.model is None: - cls.model = Llama(model_path) + cls.model = Llama(model_path, chat_format=chat_format) return cls.model @classmethod def embeddings(cls, text): return cls.model.create_embedding(text) + + @staticmethod + def stitch_prompt(messages: list, prompt_template: Dict[str, str]) -> str: + system_prompt_template = prompt_template["system_prompt_template"] + default_system_text = prompt_template["default_system_text"] + user_prompt_template = prompt_template["user_prompt_template"] + assistant_prompt_template = prompt_template["assistant_prompt_template"] + request_assistant_response_token = prompt_template.get("request_assistant_response_token", "") + + system_prompt, chat_prompt = "", "" + for message in messages: + role = message["role"] + content = message["content"] + if role == "system": + system_prompt = system_prompt_template.format(content) + elif role == "user": + chat_prompt += user_prompt_template.format(content) + elif role == "assistant": + chat_prompt += assistant_prompt_template.format(content) + + if not system_prompt: + system_prompt = system_prompt_template.format(default_system_text) + + return system_prompt + chat_prompt + request_assistant_response_token From fe09713ca50415973cd4a6bc610806b4194c14e0 Mon Sep 17 00:00:00 2001 From: biswaroop1547 Date: Wed, 25 Oct 2023 23:54:38 +0530 Subject: [PATCH 2/4] add: test for chatml --- cht-llama-cpp/tests/test_views.py | 34 ++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/cht-llama-cpp/tests/test_views.py b/cht-llama-cpp/tests/test_views.py index cac9504..262e13b 100644 --- a/cht-llama-cpp/tests/test_views.py +++ b/cht-llama-cpp/tests/test_views.py @@ -1,5 +1,8 @@ +import json + from fastapi.testclient import TestClient -from main import get_application +from main import PROMPT_TEMPLATE_STRING, get_application +from models import LLaMACPPBasedModel def test_chat_llama_cpp() -> None: @@ -24,3 +27,32 @@ def test_chat_llama_cpp() -> None: }, ) assert response.status_code == 200 + + +def test_chatml_stitch_prompt(): + messages = [ + {"role": "user", "content": "Why should we run ML models on premise?"}, + { + "role": "assistant", + "content": "There are several reasons why an organization might choose to run machine learning (ML) models on-premise:\n\n1. 
Security and privacy concerns: Running ML models on-premise allows organizations to", # noqa + }, + ] + prompt_template = json.loads(PROMPT_TEMPLATE_STRING) + assert prompt_template["template_format"] == "chatml" + result = LLaMACPPBasedModel.stitch_prompt(messages, prompt_template=prompt_template) + assert ( + result + == """<|im_start|>system +You are an helpful AI assistant. +<|im_end|> +<|im_start|>user +Why should we run ML models on premise? +<|im_end|> +<|im_start|>assistant +There are several reasons why an organization might choose to run machine learning (ML) models on-premise: + +1. Security and privacy concerns: Running ML models on-premise allows organizations to +<|im_end|> +<|im_start|>assistant +""" + ) From d4f2e2e414733cb304519782688a2d3c71ccaa53 Mon Sep 17 00:00:00 2001 From: biswaroop1547 Date: Wed, 25 Oct 2023 23:55:14 +0530 Subject: [PATCH 3/4] bump version --- cht-llama-cpp/build-aarch64-apple-darwin.sh | 2 +- cht-llama-cpp/build.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cht-llama-cpp/build-aarch64-apple-darwin.sh b/cht-llama-cpp/build-aarch64-apple-darwin.sh index f65fea9..2fc0111 100755 --- a/cht-llama-cpp/build-aarch64-apple-darwin.sh +++ b/cht-llama-cpp/build-aarch64-apple-darwin.sh @@ -1,7 +1,7 @@ #!/bin/bash set -e -export VERSION=1.1.1 +export VERSION=1.1.2 test -f venv/bin/activate || python -m venv venv source venv/bin/activate diff --git a/cht-llama-cpp/build.sh b/cht-llama-cpp/build.sh index b92a13b..8dd8420 100755 --- a/cht-llama-cpp/build.sh +++ b/cht-llama-cpp/build.sh @@ -1,6 +1,6 @@ #!/bin/bash set -e -export VERSION=1.1.0 +export VERSION=1.1.2 source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh" build_cpu ghcr.io/premai-io/chat-mistral-7b-instruct-q5 mistral-7b-instruct-v0.1.Q5_0 --build-arg="MODEL_ID=mistral-7b-instruct-v0.1.Q5_0" --build-arg="MODEL_DOWNLOAD_URL=https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q5_0.gguf" ${@:1} From 44c2a04ea0eb76510cdfcd34678e496fc9437867 Mon Sep 17 00:00:00 2001 From: Biswaroop Bhattacharjee Date: Thu, 26 Oct 2023 12:43:57 +0530 Subject: [PATCH 4/4] update: todo str from review comment Co-authored-by: Casper da Costa-Luis --- cht-llama-cpp/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cht-llama-cpp/models.py b/cht-llama-cpp/models.py index 756e0c2..2aec3cb 100644 --- a/cht-llama-cpp/models.py +++ b/cht-llama-cpp/models.py @@ -12,7 +12,7 @@ def initiate_chatml_prompt_template( messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> llama_chat_format.ChatFormatterResponse: - # Until https://github.com/abetlen/llama-cpp-python/issues/717 supports ChatML. + # TODO: drop when https://github.com/abetlen/llama-cpp-python/issues/717 supports ChatML _prompt = LLaMACPPBasedModel.stitch_prompt(messages, LLaMACPPBasedModel.PROMPT_TEMPLATE) return llama_chat_format.ChatFormatterResponse(prompt=_prompt)
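
Note on the remaining "# TODO: handle postprocessing for streaming responses" in generate(): only the non-streaming branch strips the trailing <|im_end|> token from the assistant message. Below is a minimal sketch of how the streamed iterator could be filtered as well, assuming the OpenAI-style chunk shape {"choices": [{"delta": {"content": ...}}]} that create_chat_completion(..., stream=True) yields; the helper name strip_chatml_stop_token and the final flushed chunk are illustrative and not part of these patches.

    from typing import Any, Dict, Iterator


    def strip_chatml_stop_token(
        chunks: Iterator[Dict[str, Any]], stop_token: str = "<|im_end|>"
    ) -> Iterator[Dict[str, Any]]:
        """Drop everything from stop_token onwards in a streamed chat completion."""
        held = ""  # tail held back in case the stop token is split across chunks
        for chunk in chunks:
            delta = chunk["choices"][0].get("delta", {})
            if "content" not in delta:
                yield chunk  # e.g. the initial {"role": "assistant"} chunk
                continue
            held += delta["content"]
            if stop_token in held:
                head = held.split(stop_token)[0]
                if head:
                    chunk["choices"][0]["delta"]["content"] = head
                    yield chunk
                return
            # Emit all but the last len(stop_token) characters; keep those buffered
            # so a stop token split across two chunks is still caught.
            cut = max(len(held) - len(stop_token), 0)
            emit, held = held[:cut], held[cut:]
            if emit:
                chunk["choices"][0]["delta"]["content"] = emit
                yield chunk
        if held:
            # Model stopped on its own; flush the buffered tail (simplified chunk:
            # a real handler would reuse the stream's id/model/created fields).
            yield {"choices": [{"index": 0, "delta": {"content": held}, "finish_reason": None}]}

In generate() the streamed iterator could then be wrapped before being returned (e.g. cht_resp = strip_chatml_stop_token(cht_resp) when stream is true and the ChatML template is active), mirroring the split("\n<|im_end|>") post-processing already applied to non-streaming responses.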